1+ from __future__ import annotations
2+
13import ast
2- from typing import Optional , Union
4+ import sys
35
46from .common import CheckContext , set_current
57
68
79class RulesOfHooksVisitor (ast .NodeVisitor ):
810 def __init__ (self , context : CheckContext ) -> None :
911 self ._context = context
10- self ._current_hook : Optional [ast .FunctionDef ] = None
11- self ._current_component : Optional [ast .FunctionDef ] = None
12- self ._current_function : Optional [ast .FunctionDef ] = None
13- self ._current_call : Optional [ast .Call ] = None
14- self ._current_conditional : Union [None , ast .If , ast .IfExp , ast .Try ] = None
15- self ._current_loop : Union [None , ast .For , ast .While ] = None
12+ self ._current_call : ast .Call | None = None
13+ self ._current_component : ast .FunctionDef | None = None
14+ self ._current_conditional : ast .If | ast .IfExp | ast .Try | None = None
15+ self ._current_early_return : ast .Return | None = None
16+ self ._current_function : ast .FunctionDef | None = None
17+ self ._current_hook : ast .FunctionDef | None = None
18+ self ._current_loop : ast .For | ast .While | None = None
1619
1720 def visit_FunctionDef (self , node : ast .FunctionDef ) -> None :
1821 if self ._context .is_hook_def (node ):
@@ -24,6 +27,7 @@ def visit_FunctionDef(self, node: ast.FunctionDef) -> None:
2427 # we need to reset these before enter new hook
2528 conditional = None ,
2629 loop = None ,
30+ early_return = None ,
2731 ):
2832 self .generic_visit (node )
2933 elif self ._context .is_component_def (node ):
@@ -34,13 +38,14 @@ def visit_FunctionDef(self, node: ast.FunctionDef) -> None:
3438 # we need to reset these before visiting a new component
3539 conditional = None ,
3640 loop = None ,
41+ early_return = None ,
3742 ):
3843 self .generic_visit (node )
3944 else :
4045 with set_current (self , function = node ):
4146 self .generic_visit (node )
4247
43- def _visit_hook_usage (self , node : Union [ ast .Name , ast .Attribute ] ) -> None :
48+ def _visit_hook_usage (self , node : ast .Name | ast .Attribute ) -> None :
4449 self ._check_if_propper_hook_usage (node )
4550
4651 visit_Attribute = _visit_hook_usage
@@ -53,6 +58,7 @@ def _visit_conditional(self, node: ast.AST) -> None:
5358 visit_If = _visit_conditional
5459 visit_IfExp = _visit_conditional
5560 visit_Try = _visit_conditional
61+ visit_Match = _visit_conditional
5662
5763 def _visit_loop (self , node : ast .AST ) -> None :
5864 with set_current (self , loop = node ):
@@ -61,14 +67,15 @@ def _visit_loop(self, node: ast.AST) -> None:
6167 visit_For = _visit_loop
6268 visit_While = _visit_loop
6369
70+ def visit_Return (self , node : ast .Return ) -> None :
71+ self ._current_early_return = node
72+
6473 def _check_if_hook_defined_in_function (self , node : ast .FunctionDef ) -> None :
6574 if self ._current_function is not None :
6675 msg = f"hook { node .name !r} defined as closure in function { self ._current_function .name !r} "
6776 self ._context .add_error (100 , node , msg )
6877
69- def _check_if_propper_hook_usage (
70- self , node : Union [ast .Name , ast .Attribute ]
71- ) -> None :
78+ def _check_if_propper_hook_usage (self , node : ast .Name | ast .Attribute ) -> None :
7279 if isinstance (node , ast .Name ):
7380 name = node .id
7481 else :
@@ -83,14 +90,24 @@ def _check_if_propper_hook_usage(
8390
8491 loop_or_conditional = self ._current_conditional or self ._current_loop
8592 if loop_or_conditional is not None :
86- node_type = type (loop_or_conditional )
87- node_type_to_name = {
88- ast .If : "if statement" ,
89- ast .IfExp : "inline if expression" ,
90- ast .Try : "try statement" ,
91- ast .For : "for loop" ,
92- ast .While : "while loop" ,
93- }
94- node_name = node_type_to_name [node_type ]
93+ node_name = _NODE_TYPE_TO_NAME [type (loop_or_conditional )]
9594 msg = f"hook { name !r} used inside { node_name } "
9695 self ._context .add_error (102 , node , msg )
96+
97+ if self ._current_early_return :
98+ self ._context .add_error (
99+ 103 ,
100+ node ,
101+ f"hook { name !r} used after an early return" ,
102+ )
103+
104+
105+ _NODE_TYPE_TO_NAME = {
106+ ast .If : "if statement" ,
107+ ast .IfExp : "inline if expression" ,
108+ ast .Try : "try statement" ,
109+ ast .For : "for loop" ,
110+ ast .While : "while loop" ,
111+ }
112+ if sys .version_info >= (3 , 10 ):
113+ _NODE_TYPE_TO_NAME [ast .Match ] = "match statement"
0 commit comments