11import ast
2- from typing import Optional , List , Union , Set
2+ from typing import Optional , Union , Set
33
4- from .utils import is_hook_def , is_element_def , ErrorVisitor
4+ from .utils import is_hook_def , is_element_def , ErrorVisitor , set_current
55
66
77HOOKS_WITH_DEPS = ("use_effect" , "use_callback" , "use_memo" )
1010class ExhaustiveDepsVisitor (ErrorVisitor ):
1111 def __init__ (self ) -> None :
1212 super ().__init__ ()
13+ self ._current_function : Optional [ast .FunctionDef ] = None
1314 self ._current_hook_or_element : Optional [ast .FunctionDef ] = None
1415
1516 def visit_FunctionDef (self , node : ast .FunctionDef ) -> None :
1617 if is_hook_def (node ) or is_element_def (node ):
17- self ._current_hook_or_element = node
18- self .generic_visit (node )
19- self ._current_hook_or_element = None
18+ with set_current (self , hook_or_element = node ):
19+ self .generic_visit (node )
2020 elif self ._current_hook_or_element is not None :
2121 for deco in node .decorator_list :
2222 if not isinstance (deco , ast .Call ):
@@ -94,30 +94,35 @@ def _check_hook_dependency_list_is_exhaustive(
9494
9595 func_name = "lambda" if isinstance (func , ast .Lambda ) else func .name
9696
97- visitor = _MissingNameOrAttrVisitor (
98- hook_name ,
99- func_name ,
100- _param_names_of_function_def (func ),
101- dep_names ,
97+ top_level_variable_finder = _TopLevelVariableFinder ()
98+ top_level_variable_finder .visit (self ._current_hook_or_element )
99+ variables_defined_in_scope = top_level_variable_finder .variable_names
100+
101+ missing_name_finder = _MissingNameFinder (
102+ hook_name = hook_name ,
103+ func_name = func_name ,
104+ dep_names = dep_names ,
105+ names_in_scope = variables_defined_in_scope ,
106+ ignore_names = _param_names_of_function_def (func ),
102107 )
103108 if isinstance (func .body , list ):
104109 for b in func .body :
105- visitor .visit (b )
110+ missing_name_finder .visit (b )
106111 else :
107- visitor .visit (func .body )
112+ missing_name_finder .visit (func .body )
108113
109- self .errors .extend (visitor .errors )
114+ self .errors .extend (missing_name_finder .errors )
110115
111116 def _get_dependency_names_from_expression (
112117 self , hook_name : str , dependency_expr : Optional [ast .expr ]
113- ) -> Optional [List [str ]]:
118+ ) -> Optional [Set [str ]]:
114119 if dependency_expr is None :
115- return []
120+ return set ()
116121 elif isinstance (dependency_expr , (ast .List , ast .Tuple )):
117- dep_names : List [str ] = []
122+ dep_names : Set [str ] = set ()
118123 for elt in dependency_expr .elts :
119124 if isinstance (elt , ast .Name ):
120- dep_names .append (elt .id )
125+ dep_names .add (elt .id )
121126 else :
122127 # ideally we could deal with some common use cases, but since React's
123128 # own linter doesn't do this we'll just take the easy route for now:
@@ -144,24 +149,26 @@ def _get_dependency_names_from_expression(
144149 return None
145150
146151
147- class _MissingNameOrAttrVisitor (ErrorVisitor ):
152+ class _MissingNameFinder (ErrorVisitor ):
148153 def __init__ (
149154 self ,
150155 hook_name : str ,
151156 func_name : str ,
152- ignore_names : List [str ],
153- dep_names : List [str ],
157+ dep_names : Set [str ],
158+ ignore_names : Set [str ],
159+ names_in_scope : Set [str ],
154160 ) -> None :
155161 super ().__init__ ()
156162 self ._hook_name = hook_name
157163 self ._func_name = func_name
158164 self ._ignore_names = ignore_names
159165 self ._dep_names = dep_names
166+ self ._names_in_scope = names_in_scope
160167 self .used_deps : Set [str ] = set ()
161168
162169 def visit_Name (self , node : ast .Name ) -> None :
163170 node_id = node .id
164- if node_id not in self ._ignore_names :
171+ if node_id not in self ._ignore_names and node_id in self . _names_in_scope :
165172 if node_id in self ._dep_names :
166173 self .used_deps .add (node_id )
167174 else :
@@ -175,12 +182,33 @@ def visit_Name(self, node: ast.Name) -> None:
175182 )
176183
177184
178- def _param_names_of_function_def (func : Union [ast .FunctionDef , ast .Lambda ]) -> List [str ]:
179- names : List [str ] = []
180- names .extend (a .arg for a in func .args .args )
181- names .extend (kw .arg for kw in func .args .kwonlyargs )
185+ class _TopLevelVariableFinder (ast .NodeVisitor ):
186+ def __init__ (self ) -> None :
187+ self ._scope_entered = False
188+ self ._current_scope_is_top_level = True
189+ self .variable_names : Set [str ] = set ()
190+
191+ def visit_Name (self , node : ast .Name ) -> None :
192+ if isinstance (node .ctx , ast .Store ):
193+ self .variable_names .add (node .id )
194+
195+ def _visit_new_scope (self , node : Union [ast .FunctionDef , ast .ClassDef ]) -> None :
196+ if not self ._scope_entered :
197+ self ._scope_entered = True
198+ self .generic_visit (node )
199+ elif self ._current_scope_is_top_level :
200+ self .variable_names .add (node .name )
201+
202+ visit_FunctionDef = _visit_new_scope
203+ visit_ClassDef = _visit_new_scope
204+
205+
206+ def _param_names_of_function_def (func : Union [ast .FunctionDef , ast .Lambda ]) -> Set [str ]:
207+ names : Set [str ] = set ()
208+ names .update (a .arg for a in func .args .args )
209+ names .update (kw .arg for kw in func .args .kwonlyargs )
182210 if func .args .vararg is not None :
183- names .append (func .args .vararg .arg )
211+ names .add (func .args .vararg .arg )
184212 if func .args .kwarg is not None :
185- names .append (func .args .kwarg .arg )
213+ names .add (func .args .kwarg .arg )
186214 return names
0 commit comments