11import ast
2- from typing import Any , AsyncIterator , Callable , Dict , Iterator , Optional , Type
2+ from abc import ABC
3+ from collections import defaultdict
4+
5+ from robot .parsing .model .statements import Statement
6+ from typing_extensions import Any , AsyncIterator , Callable , Dict , Iterator , Optional , Type , Union
37
48__all__ = ["iter_fields" , "iter_child_nodes" , "AsyncVisitor" ]
59
610
11+ def _patch_robot () -> None :
12+ if hasattr (Statement , "_fields" ):
13+ Statement ._fields = ()
14+
15+
16+ _patch_robot ()
17+
18+
719def iter_fields (node : ast .AST ) -> Iterator [Any ]:
8- """
9- Yield a tuple of ``(fieldname, value)`` for each field in ``node._fields``
10- that is present on *node*.
11- """
1220 for field in node ._fields :
1321 try :
1422 yield field , getattr (node , field )
1523 except AttributeError :
1624 pass
1725
1826
27+ def iter_field_values (node : ast .AST ) -> Iterator [Any ]:
28+ for field in node ._fields :
29+ try :
30+ yield getattr (node , field )
31+ except AttributeError :
32+ pass
33+
34+
1935def iter_child_nodes (node : ast .AST ) -> Iterator [ast .AST ]:
20- """
21- Yield all direct child nodes of *node*, that is, all fields that are nodes
22- and all items of fields that are lists of nodes.
23- """
2436 for _name , field in iter_fields (node ):
2537 if isinstance (field , ast .AST ):
2638 yield field
@@ -46,60 +58,69 @@ async def iter_nodes(node: ast.AST) -> AsyncIterator[ast.AST]:
4658 yield n
4759
4860
49- class VisitorFinder :
50- __NOT_SET = object ()
61+ class _NotSet :
62+ pass
63+
5164
52- def __init__ (self ) -> None :
53- self .__cache : Dict [Type [Any ], Optional [Callable [..., Any ]]] = {}
65+ class VisitorFinder (ABC ):
66+ __NOT_SET = _NotSet ()
67+ __cls_finder_cache__ : Dict [Type [Any ], Union [Callable [..., Any ], None , _NotSet ]]
5468
55- def __find_visitor (self , cls : Type [Any ]) -> Optional [Callable [..., Any ]]:
56- if cls is ast .AST :
69+ def __init_subclass__ (cls , ** kwargs : Any ) -> None :
70+ super ().__init_subclass__ (** kwargs )
71+ cls .__cls_finder_cache__ = defaultdict (lambda : cls .__NOT_SET )
72+
73+ @classmethod
74+ def __find_visitor (cls , node_cls : Type [Any ]) -> Optional [Callable [..., Any ]]:
75+ if node_cls is ast .AST :
5776 return None
58- method_name = "visit_" + cls .__name__
59- if hasattr (self , method_name ):
60- method = getattr (self , method_name )
61- if callable (method ):
62- return method # type: ignore
63- for base in cls .__bases__ :
64- method = self ._find_visitor (base )
77+ method_name = "visit_" + node_cls .__name__
78+ method = getattr (cls , method_name , None )
79+ if callable (method ):
80+ return method # type: ignore[no-any-return]
81+ for base in node_cls .__bases__ :
82+ method = cls ._find_visitor (base )
6583 if method :
66- return method # type: ignore
84+ return method
6785 return None
6886
69- def _find_visitor (self , cls : Type [Any ]) -> Optional [Callable [..., Any ]]:
70- r = self .__cache .get (cls , self .__NOT_SET )
71- if r is self .__NOT_SET :
72- self .__cache [cls ] = r = self .__find_visitor (cls )
73- return r # type: ignore
87+ @classmethod
88+ def _find_visitor (cls , node_cls : Type [Any ]) -> Optional [Callable [..., Any ]]:
89+ result = cls .__cls_finder_cache__ [node_cls ]
90+ if result is cls .__NOT_SET :
91+ result = cls .__cls_finder_cache__ [node_cls ] = cls .__find_visitor (node_cls )
92+ return result # type: ignore[return-value]
7493
7594
7695class AsyncVisitor (VisitorFinder ):
7796 async def visit (self , node : ast .AST ) -> None :
78- visitor = self ._find_visitor (type (node )) or self .generic_visit
79- await visitor (node )
97+ visitor = self ._find_visitor (type (node )) or self .__class__ . generic_visit
98+ await visitor (self , node )
8099
81100 async def generic_visit (self , node : ast .AST ) -> None :
82- """Called if no explicit visitor function exists for a node."""
83- for _ , value in iter_fields (node ):
84- if isinstance (value , list ):
101+ for value in iter_field_values (node ):
102+ if value is None :
103+ continue
104+ if isinstance (value , ast .AST ):
105+ await self .visit (value )
106+ elif isinstance (value , list ):
85107 for item in value :
86108 if isinstance (item , ast .AST ):
87109 await self .visit (item )
88- elif isinstance (value , ast .AST ):
89- await self .visit (value )
90110
91111
92112class Visitor (VisitorFinder ):
93113 def visit (self , node : ast .AST ) -> None :
94- visitor = self ._find_visitor (type (node )) or self .generic_visit
95- visitor (node )
114+ visitor = self ._find_visitor (type (node )) or self .__class__ . generic_visit
115+ visitor (self , node )
96116
97117 def generic_visit (self , node : ast .AST ) -> None :
98- """Called if no explicit visitor function exists for a node."""
99- for field , value in iter_fields (node ):
118+ for value in iter_field_values (node ):
119+ if value is None :
120+ continue
100121 if isinstance (value , list ):
101122 for item in value :
102123 if isinstance (item , ast .AST ):
103124 self .visit (item )
104- elif isinstance ( value , ast . AST ) :
125+ else :
105126 self .visit (value )
0 commit comments