|
2 | 2 |
|
3 | 3 | import ast |
4 | 4 | import platform |
| 5 | +from dataclasses import dataclass |
5 | 6 | from pathlib import Path |
6 | 7 | from typing import TYPE_CHECKING |
7 | 8 |
|
|
20 | 21 | from codeflash.models.models import CodePosition |
21 | 22 |
|
22 | 23 |
|
| 24 | +@dataclass(frozen=True) |
| 25 | +class FunctionCallNodeArguments: |
| 26 | + args: list[ast.expr] |
| 27 | + keywords: list[ast.keyword] |
| 28 | + |
| 29 | + |
| 30 | +def get_call_arguments(call_node: ast.Call) -> FunctionCallNodeArguments: |
| 31 | + return FunctionCallNodeArguments(call_node.args, call_node.keywords) |
| 32 | + |
| 33 | + |
23 | 34 | def node_in_call_position(node: ast.AST, call_positions: list[CodePosition]) -> bool: |
24 | | - if isinstance(node, ast.Call) and hasattr(node, "lineno") and hasattr(node, "col_offset"): |
25 | | - for pos in call_positions: |
26 | | - if ( |
27 | | - pos.line_no is not None |
28 | | - and node.end_lineno is not None |
29 | | - and node.lineno <= pos.line_no <= node.end_lineno |
30 | | - ): |
31 | | - if pos.line_no == node.lineno and node.col_offset <= pos.col_no: |
32 | | - return True |
33 | | - if ( |
34 | | - pos.line_no == node.end_lineno |
35 | | - and node.end_col_offset is not None |
36 | | - and node.end_col_offset >= pos.col_no |
37 | | - ): |
38 | | - return True |
39 | | - if node.lineno < pos.line_no < node.end_lineno: |
40 | | - return True |
| 35 | + # Profile: The most meaningful speedup here is to reduce attribute lookup and to localize call_positions if not empty. |
| 36 | + # Small optimizations for tight loop: |
| 37 | + if isinstance(node, ast.Call): |
| 38 | + node_lineno = getattr(node, "lineno", None) |
| 39 | + node_col_offset = getattr(node, "col_offset", None) |
| 40 | + node_end_lineno = getattr(node, "end_lineno", None) |
| 41 | + node_end_col_offset = getattr(node, "end_col_offset", None) |
| 42 | + if node_lineno is not None and node_col_offset is not None and node_end_lineno is not None: |
| 43 | + # Faster loop: reduce attribute lookups, use local variables for conditionals. |
| 44 | + for pos in call_positions: |
| 45 | + pos_line = pos.line_no |
| 46 | + if pos_line is not None and node_lineno <= pos_line <= node_end_lineno: |
| 47 | + if pos_line == node_lineno and node_col_offset <= pos.col_no: |
| 48 | + return True |
| 49 | + if ( |
| 50 | + pos_line == node_end_lineno |
| 51 | + and node_end_col_offset is not None |
| 52 | + and node_end_col_offset >= pos.col_no |
| 53 | + ): |
| 54 | + return True |
| 55 | + if node_lineno < pos_line < node_end_lineno: |
| 56 | + return True |
41 | 57 | return False |
42 | 58 |
|
43 | 59 |
|
@@ -73,66 +89,235 @@ def __init__( |
73 | 89 | def find_and_update_line_node( |
74 | 90 | self, test_node: ast.stmt, node_name: str, index: str, test_class_name: str | None = None |
75 | 91 | ) -> Iterable[ast.stmt] | None: |
| 92 | + # Major optimization: since ast.walk is *very* expensive for big trees and only checks for ast.Call, |
| 93 | + # it's much more efficient to visit nodes manually. We'll only descend into expressions/statements. |
| 94 | + |
| 95 | + # Helper for manual walk |
| 96 | + def iter_ast_calls(node): # noqa: ANN202, ANN001 |
| 97 | + # Generator to yield each ast.Call in test_node, preserves node identity |
| 98 | + stack = [node] |
| 99 | + while stack: |
| 100 | + n = stack.pop() |
| 101 | + if isinstance(n, ast.Call): |
| 102 | + yield n |
| 103 | + # Instead of using ast.walk (which calls iter_child_nodes under the hood in Python, which copy lists and stack-frames for EVERY node), |
| 104 | + # do a specialized BFS with only the necessary attributes |
| 105 | + for _field, value in ast.iter_fields(n): |
| 106 | + if isinstance(value, list): |
| 107 | + for item in reversed(value): |
| 108 | + if isinstance(item, ast.AST): |
| 109 | + stack.append(item) # noqa: PERF401 |
| 110 | + elif isinstance(value, ast.AST): |
| 111 | + stack.append(value) |
| 112 | + |
| 113 | + # This change improves from O(N) stack-frames per child-node to a single stack, less python call overhead |
| 114 | + return_statement = [test_node] |
76 | 115 | call_node = None |
77 | | - for node in ast.walk(test_node): |
78 | | - if isinstance(node, ast.Call) and node_in_call_position(node, self.call_positions): |
79 | | - call_node = node |
80 | | - if isinstance(node.func, ast.Name): |
81 | | - function_name = node.func.id |
82 | 116 |
|
83 | | - if self.function_object.is_async: |
| 117 | + # Minor optimization: Convert mode, function_name, test_class_name, qualified_name, etc to locals |
| 118 | + fn_obj = self.function_object |
| 119 | + module_path = self.module_path |
| 120 | + mode = self.mode |
| 121 | + qualified_name = fn_obj.qualified_name |
| 122 | + |
| 123 | + # Use locals for all 'current' values, only look up class/function/constant AST object once. |
| 124 | + codeflash_loop_index = ast.Name(id="codeflash_loop_index", ctx=ast.Load()) |
| 125 | + codeflash_cur = ast.Name(id="codeflash_cur", ctx=ast.Load()) |
| 126 | + codeflash_con = ast.Name(id="codeflash_con", ctx=ast.Load()) |
| 127 | + |
| 128 | + for node in iter_ast_calls(test_node): |
| 129 | + if not node_in_call_position(node, self.call_positions): |
| 130 | + continue |
| 131 | + |
| 132 | + call_node = node |
| 133 | + all_args = get_call_arguments(call_node) |
| 134 | + # Two possible call types: Name and Attribute |
| 135 | + node_func = node.func |
| 136 | + |
| 137 | + if isinstance(node_func, ast.Name): |
| 138 | + function_name = node_func.id |
| 139 | + |
| 140 | + # Check if this is the function we want to instrument |
| 141 | + if function_name != fn_obj.function_name: |
| 142 | + continue |
| 143 | + |
| 144 | + if fn_obj.is_async: |
| 145 | + return [test_node] |
| 146 | + |
| 147 | + # Build once, reuse objects. |
| 148 | + inspect_name = ast.Name(id="inspect", ctx=ast.Load()) |
| 149 | + bind_call = ast.Assign( |
| 150 | + targets=[ast.Name(id="_call__bound__arguments", ctx=ast.Store())], |
| 151 | + value=ast.Call( |
| 152 | + func=ast.Attribute( |
| 153 | + value=ast.Call( |
| 154 | + func=ast.Attribute(value=inspect_name, attr="signature", ctx=ast.Load()), |
| 155 | + args=[ast.Name(id=function_name, ctx=ast.Load())], |
| 156 | + keywords=[], |
| 157 | + ), |
| 158 | + attr="bind", |
| 159 | + ctx=ast.Load(), |
| 160 | + ), |
| 161 | + args=all_args.args, |
| 162 | + keywords=all_args.keywords, |
| 163 | + ), |
| 164 | + lineno=test_node.lineno, |
| 165 | + col_offset=test_node.col_offset, |
| 166 | + ) |
| 167 | + |
| 168 | + apply_defaults = ast.Expr( |
| 169 | + value=ast.Call( |
| 170 | + func=ast.Attribute( |
| 171 | + value=ast.Name(id="_call__bound__arguments", ctx=ast.Load()), |
| 172 | + attr="apply_defaults", |
| 173 | + ctx=ast.Load(), |
| 174 | + ), |
| 175 | + args=[], |
| 176 | + keywords=[], |
| 177 | + ), |
| 178 | + lineno=test_node.lineno + 1, |
| 179 | + col_offset=test_node.col_offset, |
| 180 | + ) |
| 181 | + |
| 182 | + node.func = ast.Name(id="codeflash_wrap", ctx=ast.Load()) |
| 183 | + base_args = [ |
| 184 | + ast.Name(id=function_name, ctx=ast.Load()), |
| 185 | + ast.Constant(value=module_path), |
| 186 | + ast.Constant(value=test_class_name or None), |
| 187 | + ast.Constant(value=node_name), |
| 188 | + ast.Constant(value=qualified_name), |
| 189 | + ast.Constant(value=index), |
| 190 | + codeflash_loop_index, |
| 191 | + ] |
| 192 | + # Extend with BEHAVIOR extras if needed |
| 193 | + if mode == TestingMode.BEHAVIOR: |
| 194 | + base_args += [codeflash_cur, codeflash_con] |
| 195 | + # Extend with call args (performance) or starred bound args (behavior) |
| 196 | + if mode == TestingMode.PERFORMANCE: |
| 197 | + base_args += call_node.args |
| 198 | + else: |
| 199 | + base_args.append( |
| 200 | + ast.Starred( |
| 201 | + value=ast.Attribute( |
| 202 | + value=ast.Name(id="_call__bound__arguments", ctx=ast.Load()), |
| 203 | + attr="args", |
| 204 | + ctx=ast.Load(), |
| 205 | + ), |
| 206 | + ctx=ast.Load(), |
| 207 | + ) |
| 208 | + ) |
| 209 | + node.args = base_args |
| 210 | + # Prepare keywords |
| 211 | + if mode == TestingMode.BEHAVIOR: |
| 212 | + node.keywords = [ |
| 213 | + ast.keyword( |
| 214 | + value=ast.Attribute( |
| 215 | + value=ast.Name(id="_call__bound__arguments", ctx=ast.Load()), |
| 216 | + attr="kwargs", |
| 217 | + ctx=ast.Load(), |
| 218 | + ) |
| 219 | + ) |
| 220 | + ] |
| 221 | + else: |
| 222 | + node.keywords = call_node.keywords |
| 223 | + |
| 224 | + return_statement = ( |
| 225 | + [bind_call, apply_defaults, test_node] if mode == TestingMode.BEHAVIOR else [test_node] |
| 226 | + ) |
| 227 | + break |
| 228 | + if isinstance(node_func, ast.Attribute): |
| 229 | + function_to_test = node_func.attr |
| 230 | + if function_to_test == fn_obj.function_name: |
| 231 | + if fn_obj.is_async: |
84 | 232 | return [test_node] |
85 | 233 |
|
| 234 | + # Create the signature binding statements |
| 235 | + |
| 236 | + # Unparse only once |
| 237 | + function_name_expr = ast.parse(ast.unparse(node_func), mode="eval").body |
| 238 | + |
| 239 | + inspect_name = ast.Name(id="inspect", ctx=ast.Load()) |
| 240 | + bind_call = ast.Assign( |
| 241 | + targets=[ast.Name(id="_call__bound__arguments", ctx=ast.Store())], |
| 242 | + value=ast.Call( |
| 243 | + func=ast.Attribute( |
| 244 | + value=ast.Call( |
| 245 | + func=ast.Attribute(value=inspect_name, attr="signature", ctx=ast.Load()), |
| 246 | + args=[function_name_expr], |
| 247 | + keywords=[], |
| 248 | + ), |
| 249 | + attr="bind", |
| 250 | + ctx=ast.Load(), |
| 251 | + ), |
| 252 | + args=all_args.args, |
| 253 | + keywords=all_args.keywords, |
| 254 | + ), |
| 255 | + lineno=test_node.lineno, |
| 256 | + col_offset=test_node.col_offset, |
| 257 | + ) |
| 258 | + |
| 259 | + apply_defaults = ast.Expr( |
| 260 | + value=ast.Call( |
| 261 | + func=ast.Attribute( |
| 262 | + value=ast.Name(id="_call__bound__arguments", ctx=ast.Load()), |
| 263 | + attr="apply_defaults", |
| 264 | + ctx=ast.Load(), |
| 265 | + ), |
| 266 | + args=[], |
| 267 | + keywords=[], |
| 268 | + ), |
| 269 | + lineno=test_node.lineno + 1, |
| 270 | + col_offset=test_node.col_offset, |
| 271 | + ) |
| 272 | + |
86 | 273 | node.func = ast.Name(id="codeflash_wrap", ctx=ast.Load()) |
87 | | - node.args = [ |
88 | | - ast.Name(id=function_name, ctx=ast.Load()), |
89 | | - ast.Constant(value=self.module_path), |
| 274 | + base_args = [ |
| 275 | + function_name_expr, |
| 276 | + ast.Constant(value=module_path), |
90 | 277 | ast.Constant(value=test_class_name or None), |
91 | 278 | ast.Constant(value=node_name), |
92 | | - ast.Constant(value=self.function_object.qualified_name), |
| 279 | + ast.Constant(value=qualified_name), |
93 | 280 | ast.Constant(value=index), |
94 | | - ast.Name(id="codeflash_loop_index", ctx=ast.Load()), |
95 | | - *( |
96 | | - [ast.Name(id="codeflash_cur", ctx=ast.Load()), ast.Name(id="codeflash_con", ctx=ast.Load())] |
97 | | - if self.mode == TestingMode.BEHAVIOR |
98 | | - else [] |
99 | | - ), |
100 | | - *call_node.args, |
| 281 | + codeflash_loop_index, |
101 | 282 | ] |
102 | | - node.keywords = call_node.keywords |
103 | | - break |
104 | | - if isinstance(node.func, ast.Attribute): |
105 | | - function_to_test = node.func.attr |
106 | | - if function_to_test == self.function_object.function_name: |
107 | | - if self.function_object.is_async: |
108 | | - return [test_node] |
109 | | - |
110 | | - function_name = ast.unparse(node.func) |
111 | | - node.func = ast.Name(id="codeflash_wrap", ctx=ast.Load()) |
112 | | - node.args = [ |
113 | | - ast.Name(id=function_name, ctx=ast.Load()), |
114 | | - ast.Constant(value=self.module_path), |
115 | | - ast.Constant(value=test_class_name or None), |
116 | | - ast.Constant(value=node_name), |
117 | | - ast.Constant(value=self.function_object.qualified_name), |
118 | | - ast.Constant(value=index), |
119 | | - ast.Name(id="codeflash_loop_index", ctx=ast.Load()), |
120 | | - *( |
121 | | - [ |
122 | | - ast.Name(id="codeflash_cur", ctx=ast.Load()), |
123 | | - ast.Name(id="codeflash_con", ctx=ast.Load()), |
124 | | - ] |
125 | | - if self.mode == TestingMode.BEHAVIOR |
126 | | - else [] |
127 | | - ), |
128 | | - *call_node.args, |
| 283 | + if mode == TestingMode.BEHAVIOR: |
| 284 | + base_args += [codeflash_cur, codeflash_con] |
| 285 | + if mode == TestingMode.PERFORMANCE: |
| 286 | + base_args += call_node.args |
| 287 | + else: |
| 288 | + base_args.append( |
| 289 | + ast.Starred( |
| 290 | + value=ast.Attribute( |
| 291 | + value=ast.Name(id="_call__bound__arguments", ctx=ast.Load()), |
| 292 | + attr="args", |
| 293 | + ctx=ast.Load(), |
| 294 | + ), |
| 295 | + ctx=ast.Load(), |
| 296 | + ) |
| 297 | + ) |
| 298 | + node.args = base_args |
| 299 | + if mode == TestingMode.BEHAVIOR: |
| 300 | + node.keywords = [ |
| 301 | + ast.keyword( |
| 302 | + value=ast.Attribute( |
| 303 | + value=ast.Name(id="_call__bound__arguments", ctx=ast.Load()), |
| 304 | + attr="kwargs", |
| 305 | + ctx=ast.Load(), |
| 306 | + ) |
| 307 | + ) |
129 | 308 | ] |
| 309 | + else: |
130 | 310 | node.keywords = call_node.keywords |
131 | | - break |
| 311 | + |
| 312 | + # Return the signature binding statements along with the test_node |
| 313 | + return_statement = ( |
| 314 | + [bind_call, apply_defaults, test_node] if mode == TestingMode.BEHAVIOR else [test_node] |
| 315 | + ) |
| 316 | + break |
132 | 317 |
|
133 | 318 | if call_node is None: |
134 | 319 | return None |
135 | | - return [test_node] |
| 320 | + return return_statement |
136 | 321 |
|
137 | 322 | def visit_ClassDef(self, node: ast.ClassDef) -> ast.ClassDef: |
138 | 323 | # TODO: Ensure that this class inherits from unittest.TestCase. Don't modify non unittest.TestCase classes. |
@@ -593,7 +778,11 @@ def inject_profiling_into_existing_test( |
593 | 778 | ] |
594 | 779 | if mode == TestingMode.BEHAVIOR: |
595 | 780 | new_imports.extend( |
596 | | - [ast.Import(names=[ast.alias(name="sqlite3")]), ast.Import(names=[ast.alias(name="dill", asname="pickle")])] |
| 781 | + [ |
| 782 | + ast.Import(names=[ast.alias(name="inspect")]), |
| 783 | + ast.Import(names=[ast.alias(name="sqlite3")]), |
| 784 | + ast.Import(names=[ast.alias(name="dill", asname="pickle")]), |
| 785 | + ] |
597 | 786 | ) |
598 | 787 | if test_framework == "unittest" and platform.system() != "Windows": |
599 | 788 | new_imports.append(ast.Import(names=[ast.alias(name="timeout_decorator")])) |
|
0 commit comments