@@ -89,6 +89,391 @@ def check_fastmath(pkg_dir, pkg_name):
8989 return
9090
9191
92+ class FunctionCallVisitor (ast .NodeVisitor ):
93+ """
94+ A class to traverse the AST of the modules of a package to collect
95+ the call stacks of njit functions.
96+
97+ Parameters
98+ ----------
99+ pkg_dir : str
100+ The path to the package directory containing some .py files.
101+
102+ pkg_name : str
103+ The name of the package.
104+
105+ Attributes
106+ ----------
107+ module_names : list
108+ A list of module names to track the modules as the visitor traverses them.
109+
110+ call_stack : list
111+ A list of njit functions, representing a chain of function calls,
112+ where each element is a string of the form "module_name.func_name".
113+
114+ out : list
115+ A list of unique `call_stack`s.
116+
117+ njit_funcs : list
118+ A list of all njit functions in `pkg_dir`'s modules. Each element is a tuple
119+ of the form `(module_name, func_name)`.
120+
121+ njit_modules : set
122+ A set that contains the names of all modules, each of which contains at least
123+ one njit function.
124+
125+ njit_nodes : dict
126+ A dictionary mapping njit function names to their corresponding AST nodes.
127+ A key is a string, and it is of the form "module_name.func_name", and its
128+ corresponding value is the AST node- with type ast.FunctionDef- of that
129+ function.
130+
131+ ast_modules : dict
132+ A dictionary mapping module names to their corresponding AST objects. A key
133+ is the name of a module, and its corresponding value is the content of that
134+ module as an AST object.
135+
136+ Methods
137+ -------
138+ push_module(module_name)
139+ Push the name of a module onto the stack `module_names`.
140+
141+ pop_module()
142+ Pop the last module name from the stack `module_names`.
143+
144+ push_call_stack(module_name, func_name)
145+ Push a function call onto the stack of function calls, `call_stack`.
146+
147+ pop_call_stack()
148+ Pop the last function call from the stack of function calls, `call_stack`
149+
150+ goto_deeper_func(node)
151+ Calls the visit method from class `ast.NodeVisitor` on all children of
152+ the `node`.
153+
154+ goto_next_func(node)
155+ Calls the visit method from class `ast.NodeVisitor` on all children of
156+ the `node`.
157+
158+ push_out()
159+ Push the current function call stack, `call_stack`, onto the output list, `out`,
160+ unless it is already included in one of the so-far-collected call stacks.
161+
162+ visit_Call(node)
163+ This method is called when the visitor encounters a function call in the AST. It
164+ checks if the called function is a njit function and, if so, traverses its AST
165+ to collect its call stack.
166+ """
167+
168+ def __init__ (self , pkg_dir , pkg_name ):
169+ """
170+ Initialize the FunctionCallVisitor class. This method sets up the necessary
171+ attributes and prepares the visitor for traversing the AST of STUMPY's modules.
172+
173+ Parameters
174+ ----------
175+ pkg_dir : str
176+ The path to the package directory containing some .py files.
177+
178+ pkg_name : str
179+ The name of the package.
180+
181+ Returns
182+ -------
183+ None
184+ """
185+ super ().__init__ ()
186+ self .module_names = []
187+ self .call_stack = []
188+ self .out = []
189+
190+ # Setup lists, dicts, and ast objects
191+ self .njit_funcs = get_njit_funcs (pkg_dir )
192+ self .njit_modules = set (mod_name for mod_name , func_name in self .njit_funcs )
193+ self .njit_nodes = {}
194+ self .ast_modules = {}
195+
196+ filepaths = sorted (f for f in pathlib .Path (pkg_dir ).iterdir () if f .is_file ())
197+ ignore = ["__init__.py" , "__pycache__" ]
198+
199+ for filepath in filepaths :
200+ file_name = filepath .name
201+ if (
202+ file_name not in ignore
203+ and not file_name .startswith ("gpu" )
204+ and str (filepath ).endswith (".py" )
205+ ):
206+ module_name = file_name .replace (".py" , "" )
207+ file_contents = ""
208+ with open (filepath , encoding = "utf8" ) as f :
209+ file_contents = f .read ()
210+ self .ast_modules [module_name ] = ast .parse (file_contents )
211+
212+ for node in self .ast_modules [module_name ].body :
213+ if isinstance (node , ast .FunctionDef ):
214+ func_name = node .name
215+ if (module_name , func_name ) in self .njit_funcs :
216+ self .njit_nodes [f"{ module_name } .{ func_name } " ] = node
217+
218+ def push_module (self , module_name ):
219+ """
220+ Push a module name onto the stack of module names.
221+
222+ Parameters
223+ ----------
224+ module_name : str
225+ The name of the module to be pushed onto the stack.
226+
227+ Returns
228+ -------
229+ None
230+ """
231+ self .module_names .append (module_name )
232+
233+ return
234+
235+ def pop_module (self ):
236+ """
237+ Pop the last module name from the stack of module names.
238+
239+ Parameters
240+ ----------
241+ None
242+
243+ Returns
244+ -------
245+ None
246+ """
247+ if self .module_names :
248+ self .module_names .pop ()
249+
250+ return
251+
252+ def push_call_stack (self , module_name , func_name ):
253+ """
254+ Push a function call onto the stack of function calls.
255+
256+ Parameters
257+ ----------
258+ module_name : str
259+ A module's name
260+
261+ func_name : str
262+ A function's name
263+
264+ Returns
265+ -------
266+ None
267+ """
268+ self .call_stack .append (f"{ module_name } .{ func_name } " )
269+
270+ return
271+
272+ def pop_call_stack (self ):
273+ """
274+ Pop the last function call from the stack of function calls.
275+
276+ Parameters
277+ ----------
278+ None
279+
280+ Returns
281+ -------
282+ None
283+ """
284+ if self .call_stack :
285+ self .call_stack .pop ()
286+
287+ return
288+
289+ def goto_deeper_func (self , node ):
290+ """
291+ Calls the visit method from class `ast.NodeVisitor` on
292+ all children of the `node`.
293+
294+ Parameters
295+ ----------
296+ node : ast.AST
297+ The AST node to be visited.
298+
299+ Returns
300+ -------
301+ None
302+ """
303+ self .generic_visit (node )
304+
305+ return
306+
307+ def goto_next_func (self , node ):
308+ """
309+ Calls the visit method from class `ast.NodeVisitor` on
310+ all children of the node.
311+
312+ Parameters
313+ ----------
314+ node : ast.AST
315+ The AST node to be visited.
316+
317+ Returns
318+ -------
319+ None
320+ """
321+ self .generic_visit (node )
322+
323+ return
324+
325+ def push_out (self ):
326+ """
327+ Push the current function call stack onto the output list unless it
328+ is already included in one of the so-far-collected call stacks.
329+
330+
331+ Parameters
332+ ----------
333+ None
334+
335+ Returns
336+ -------
337+ None
338+ """
339+ unique = True
340+ for cs in self .out :
341+ if " " .join (self .call_stack ) in " " .join (cs ):
342+ unique = False
343+ break
344+
345+ if unique :
346+ self .out .append (self .call_stack .copy ())
347+
348+ return
349+
350+ def visit_Call (self , node ):
351+ """
352+ Called when visiting an AST node of type `ast.Call`.
353+
354+ Parameters
355+ ----------
356+ node : ast.Call
357+ The AST node representing a function call.
358+
359+ Returns
360+ -------
361+ None
362+ """
363+ callee_name = ast .unparse (node .func )
364+
365+ module_changed = False
366+ if "." in callee_name :
367+ new_module_name , new_func_name = callee_name .split ("." )[:2 ]
368+
369+ if new_module_name in self .njit_modules :
370+ self .push_module (new_module_name )
371+ module_changed = True
372+ else :
373+ if self .module_names :
374+ new_module_name = self .module_names [- 1 ]
375+ new_func_name = callee_name
376+ callee_name = f"{ new_module_name } .{ new_func_name } "
377+
378+ if callee_name in self .njit_nodes .keys ():
379+ callee_node = self .njit_nodes [callee_name ]
380+ self .push_call_stack (new_module_name , new_func_name )
381+ self .goto_deeper_func (callee_node )
382+ self .push_out ()
383+ self .pop_call_stack ()
384+ if module_changed :
385+ self .pop_module ()
386+
387+ self .goto_next_func (node )
388+
389+ return
390+
391+
392+ def get_njit_call_stacks (pkg_dir , pkg_name ):
393+ """
394+ Get the call stacks of all njit functions in `pkg_dir`
395+
396+ Parameters
397+ ----------
398+ pkg_dir : str
399+ The path to the package directory containing some .py files
400+
401+ pkg_name : str
402+ The name of the package
403+
404+ Returns
405+ -------
406+ out : list
407+ A list of unique function call stacks. Each item is of type list,
408+ representing a chain of function calls.
409+ """
410+ visitor = FunctionCallVisitor (pkg_dir , pkg_name )
411+
412+ for module_name in visitor .njit_modules :
413+ visitor .push_module (module_name )
414+
415+ for node in visitor .ast_modules [module_name ].body :
416+ if isinstance (node , ast .FunctionDef ):
417+ func_name = node .name
418+ if (module_name , func_name ) in visitor .njit_funcs :
419+ visitor .push_call_stack (module_name , func_name )
420+ visitor .visit (node )
421+ visitor .pop_call_stack ()
422+
423+ visitor .pop_module ()
424+
425+ return visitor .out
426+
427+
428+ def check_call_stack_fastmath (pkg_dir , pkg_name ):
429+ """
430+ Check if all njit functions in a call stack have the same `fastmath` flag.
431+ This function raises a ValueError if it finds any inconsistencies in the
432+ `fastmath` flags in at lease one call stack of njit functions.
433+
434+ Parameters
435+ ----------
436+ pkg_dir : str
437+ The path to the directory containing some .py files
438+
439+ pkg_name : str
440+ The name of the package
441+
442+ Returns
443+ -------
444+ None
445+ """
446+ # List of call stacks with inconsistent fastmath flags
447+ inconsistent_call_stacks = []
448+
449+ njit_call_stacks = get_njit_call_stacks (pkg_dir , pkg_name )
450+ for cs in njit_call_stacks :
451+ # Set the fastmath flag of the first function in the call stack
452+ # as the reference flag
453+ module_name , func_name = cs [0 ].split ("." )
454+ module = importlib .import_module (f".{ module_name } " , package = "stumpy" )
455+ func = getattr (module , func_name )
456+ flag_ref = func .targetoptions ["fastmath" ]
457+
458+ for item in cs [1 :]:
459+ module_name , func_name = cs [0 ].split ("." )
460+ module = importlib .import_module (f".{ module_name } " , package = "stumpy" )
461+ func = getattr (module , func_name )
462+ flag = func .targetoptions ["fastmath" ]
463+ if flag != flag_ref :
464+ inconsistent_call_stacks .append (cs )
465+ break
466+
467+ if len (inconsistent_call_stacks ) > 0 :
468+ msg = (
469+ "Found at least one call stack that has inconsistent `fastmath` flags. "
470+ + f"Those call stacks are:\n { inconsistent_call_stacks } \n "
471+ )
472+ raise ValueError (msg )
473+
474+ return
475+
476+
92477if __name__ == "__main__" :
93478 parser = argparse .ArgumentParser ()
94479 parser .add_argument ("--check" , dest = "pkg_dir" )
@@ -98,3 +483,4 @@ def check_fastmath(pkg_dir, pkg_name):
98483 pkg_dir = pathlib .Path (args .pkg_dir )
99484 pkg_name = pkg_dir .name
100485 check_fastmath (str (pkg_dir ), pkg_name )
486+ check_call_stack_fastmath (str (pkg_dir ), pkg_name )
0 commit comments