2626
2727import mypy .checkexpr
2828from mypy import errorcodes as codes , message_registry , nodes , operators
29- from mypy .binder import ConditionalTypeBinder , get_declaration
29+ from mypy .binder import ConditionalTypeBinder , Frame , get_declaration
3030from mypy .checkmember import (
3131 MemberContext ,
3232 analyze_decorator_or_funcbase_access ,
4141from mypy .errors import Errors , ErrorWatcher , report_internal_error
4242from mypy .expandtype import expand_self_type , expand_type , expand_type_by_instance
4343from mypy .join import join_types
44- from mypy .literals import Key , literal , literal_hash
44+ from mypy .literals import Key , extract_var_from_literal_hash , literal , literal_hash
4545from mypy .maptype import map_instance_to_supertype
4646from mypy .meet import is_overlapping_erased_types , is_overlapping_types
4747from mypy .message_registry import ErrorMessage
134134 is_final_node ,
135135)
136136from mypy .options import Options
137+ from mypy .patterns import AsPattern , StarredPattern
137138from mypy .plugin import CheckerPluginInterface , Plugin
138139from mypy .scope import Scope
139140from mypy .semanal import is_trivial_body , refers_to_fullname , set_callable_name
151152 restrict_subtype_away ,
152153 unify_generic_callable ,
153154)
154- from mypy .traverser import all_return_statements , has_return_statement
155+ from mypy .traverser import TraverserVisitor , all_return_statements , has_return_statement
155156from mypy .treetransform import TransformVisitor
156157from mypy .typeanal import check_for_explicit_any , has_any_from_unimported_type , make_optional_type
157158from mypy .typeops import (
@@ -1207,6 +1208,20 @@ def check_func_def(
12071208
12081209 # Type check body in a new scope.
12091210 with self .binder .top_frame_context ():
1211+ # Copy some type narrowings from an outer function when it seems safe enough
1212+ # (i.e. we can't find an assignment that might change the type of the
1213+ # variable afterwards).
1214+ new_frame : Frame | None = None
1215+ for frame in old_binder .frames :
1216+ for key , narrowed_type in frame .types .items ():
1217+ key_var = extract_var_from_literal_hash (key )
1218+ if key_var is not None and not self .is_var_redefined_in_outer_context (
1219+ key_var , defn .line
1220+ ):
1221+ # It seems safe to propagate the type narrowing to a nested scope.
1222+ if new_frame is None :
1223+ new_frame = self .binder .push_frame ()
1224+ new_frame .types [key ] = narrowed_type
12101225 with self .scope .push_function (defn ):
12111226 # We suppress reachability warnings when we use TypeVars with value
12121227 # restrictions: we only want to report a warning if a certain statement is
@@ -1218,6 +1233,8 @@ def check_func_def(
12181233 self .binder .suppress_unreachable_warnings ()
12191234 self .accept (item .body )
12201235 unreachable = self .binder .is_unreachable ()
1236+ if new_frame is not None :
1237+ self .binder .pop_frame (True , 0 )
12211238
12221239 if not unreachable :
12231240 if defn .is_generator or is_named_instance (
@@ -1310,6 +1327,23 @@ def check_func_def(
13101327
13111328 self .binder = old_binder
13121329
1330+ def is_var_redefined_in_outer_context (self , v : Var , after_line : int ) -> bool :
1331+ """Can the variable be assigned to at module top level or outer function?
1332+
1333+ Note that this doesn't do a full CFG analysis but uses a line number based
1334+ heuristic that isn't correct in some (rare) cases.
1335+ """
1336+ outers = self .tscope .outer_functions ()
1337+ if not outers :
1338+ # Top-level function -- outer context is top level, and we can't reason about
1339+ # globals
1340+ return True
1341+ for outer in outers :
1342+ if isinstance (outer , FuncDef ):
1343+ if find_last_var_assignment_line (outer .body , v ) >= after_line :
1344+ return True
1345+ return False
1346+
13131347 def check_unbound_return_typevar (self , typ : CallableType ) -> None :
13141348 """Fails when the return typevar is not defined in arguments."""
13151349 if isinstance (typ .ret_type , TypeVarType ) and typ .ret_type in typ .variables :
@@ -7629,3 +7663,80 @@ def collapse_walrus(e: Expression) -> Expression:
76297663 if isinstance (e , AssignmentExpr ):
76307664 return e .target
76317665 return e
7666+
7667+
7668+ def find_last_var_assignment_line (n : Node , v : Var ) -> int :
7669+ """Find the highest line number of a potential assignment to variable within node.
7670+
7671+ This supports local and global variables.
7672+
7673+ Return -1 if no assignment was found.
7674+ """
7675+ visitor = VarAssignVisitor (v )
7676+ n .accept (visitor )
7677+ return visitor .last_line
7678+
7679+
7680+ class VarAssignVisitor (TraverserVisitor ):
7681+ def __init__ (self , v : Var ) -> None :
7682+ self .last_line = - 1
7683+ self .lvalue = False
7684+ self .var_node = v
7685+
7686+ def visit_assignment_stmt (self , s : AssignmentStmt ) -> None :
7687+ self .lvalue = True
7688+ for lv in s .lvalues :
7689+ lv .accept (self )
7690+ self .lvalue = False
7691+
7692+ def visit_name_expr (self , e : NameExpr ) -> None :
7693+ if self .lvalue and e .node is self .var_node :
7694+ self .last_line = max (self .last_line , e .line )
7695+
7696+ def visit_member_expr (self , e : MemberExpr ) -> None :
7697+ old_lvalue = self .lvalue
7698+ self .lvalue = False
7699+ super ().visit_member_expr (e )
7700+ self .lvalue = old_lvalue
7701+
7702+ def visit_index_expr (self , e : IndexExpr ) -> None :
7703+ old_lvalue = self .lvalue
7704+ self .lvalue = False
7705+ super ().visit_index_expr (e )
7706+ self .lvalue = old_lvalue
7707+
7708+ def visit_with_stmt (self , s : WithStmt ) -> None :
7709+ self .lvalue = True
7710+ for lv in s .target :
7711+ if lv is not None :
7712+ lv .accept (self )
7713+ self .lvalue = False
7714+ s .body .accept (self )
7715+
7716+ def visit_for_stmt (self , s : ForStmt ) -> None :
7717+ self .lvalue = True
7718+ s .index .accept (self )
7719+ self .lvalue = False
7720+ s .body .accept (self )
7721+ if s .else_body :
7722+ s .else_body .accept (self )
7723+
7724+ def visit_assignment_expr (self , e : AssignmentExpr ) -> None :
7725+ self .lvalue = True
7726+ e .target .accept (self )
7727+ self .lvalue = False
7728+ e .value .accept (self )
7729+
7730+ def visit_as_pattern (self , p : AsPattern ) -> None :
7731+ if p .pattern is not None :
7732+ p .pattern .accept (self )
7733+ if p .name is not None :
7734+ self .lvalue = True
7735+ p .name .accept (self )
7736+ self .lvalue = False
7737+
7738+ def visit_starred_pattern (self , p : StarredPattern ) -> None :
7739+ if p .capture is not None :
7740+ self .lvalue = True
7741+ p .capture .accept (self )
7742+ self .lvalue = False
0 commit comments