Skip to content

Commit 653fa9d

Browse files
authored
Merge pull request #867 from codeflash-ai/inspect-signature-issue
Behavior Test Instrumentation to account for input mutation
2 parents dc6c4cd + f6302d0 commit 653fa9d

File tree

5 files changed

+854
-140
lines changed

5 files changed

+854
-140
lines changed

code_to_optimize/bubble_sort_method.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,3 +15,27 @@ def sorter(self, arr):
1515
arr[j + 1] = temp
1616
print("stderr test", file=sys.stderr)
1717
return arr
18+
19+
@classmethod
20+
def sorter_classmethod(cls, arr):
21+
print("codeflash stdout : BubbleSorter.sorter_classmethod() called")
22+
for i in range(len(arr)):
23+
for j in range(len(arr) - 1):
24+
if arr[j] > arr[j + 1]:
25+
temp = arr[j]
26+
arr[j] = arr[j + 1]
27+
arr[j + 1] = temp
28+
print("stderr test classmethod", file=sys.stderr)
29+
return arr
30+
31+
@staticmethod
32+
def sorter_staticmethod(arr):
33+
print("codeflash stdout : BubbleSorter.sorter_staticmethod() called")
34+
for i in range(len(arr)):
35+
for j in range(len(arr) - 1):
36+
if arr[j] > arr[j + 1]:
37+
temp = arr[j]
38+
arr[j] = arr[j + 1]
39+
arr[j + 1] = temp
40+
print("stderr test staticmethod", file=sys.stderr)
41+
return arr

codeflash/code_utils/instrument_existing_tests.py

Lines changed: 253 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
import ast
44
import platform
5+
from dataclasses import dataclass
56
from pathlib import Path
67
from typing import TYPE_CHECKING
78

@@ -20,24 +21,39 @@
2021
from codeflash.models.models import CodePosition
2122

2223

24+
@dataclass(frozen=True)
25+
class FunctionCallNodeArguments:
26+
args: list[ast.expr]
27+
keywords: list[ast.keyword]
28+
29+
30+
def get_call_arguments(call_node: ast.Call) -> FunctionCallNodeArguments:
31+
return FunctionCallNodeArguments(call_node.args, call_node.keywords)
32+
33+
2334
def node_in_call_position(node: ast.AST, call_positions: list[CodePosition]) -> bool:
24-
if isinstance(node, ast.Call) and hasattr(node, "lineno") and hasattr(node, "col_offset"):
25-
for pos in call_positions:
26-
if (
27-
pos.line_no is not None
28-
and node.end_lineno is not None
29-
and node.lineno <= pos.line_no <= node.end_lineno
30-
):
31-
if pos.line_no == node.lineno and node.col_offset <= pos.col_no:
32-
return True
33-
if (
34-
pos.line_no == node.end_lineno
35-
and node.end_col_offset is not None
36-
and node.end_col_offset >= pos.col_no
37-
):
38-
return True
39-
if node.lineno < pos.line_no < node.end_lineno:
40-
return True
35+
# Profile: The most meaningful speedup here is to reduce attribute lookup and to localize call_positions if not empty.
36+
# Small optimizations for tight loop:
37+
if isinstance(node, ast.Call):
38+
node_lineno = getattr(node, "lineno", None)
39+
node_col_offset = getattr(node, "col_offset", None)
40+
node_end_lineno = getattr(node, "end_lineno", None)
41+
node_end_col_offset = getattr(node, "end_col_offset", None)
42+
if node_lineno is not None and node_col_offset is not None and node_end_lineno is not None:
43+
# Faster loop: reduce attribute lookups, use local variables for conditionals.
44+
for pos in call_positions:
45+
pos_line = pos.line_no
46+
if pos_line is not None and node_lineno <= pos_line <= node_end_lineno:
47+
if pos_line == node_lineno and node_col_offset <= pos.col_no:
48+
return True
49+
if (
50+
pos_line == node_end_lineno
51+
and node_end_col_offset is not None
52+
and node_end_col_offset >= pos.col_no
53+
):
54+
return True
55+
if node_lineno < pos_line < node_end_lineno:
56+
return True
4157
return False
4258

4359

@@ -73,66 +89,235 @@ def __init__(
7389
def find_and_update_line_node(
7490
self, test_node: ast.stmt, node_name: str, index: str, test_class_name: str | None = None
7591
) -> Iterable[ast.stmt] | None:
92+
# Major optimization: since ast.walk is *very* expensive for big trees and only checks for ast.Call,
93+
# it's much more efficient to visit nodes manually. We'll only descend into expressions/statements.
94+
95+
# Helper for manual walk
96+
def iter_ast_calls(node): # noqa: ANN202, ANN001
97+
# Generator to yield each ast.Call in test_node, preserves node identity
98+
stack = [node]
99+
while stack:
100+
n = stack.pop()
101+
if isinstance(n, ast.Call):
102+
yield n
103+
# Instead of using ast.walk (which calls iter_child_nodes under the hood in Python, which copy lists and stack-frames for EVERY node),
104+
# do a specialized BFS with only the necessary attributes
105+
for _field, value in ast.iter_fields(n):
106+
if isinstance(value, list):
107+
for item in reversed(value):
108+
if isinstance(item, ast.AST):
109+
stack.append(item) # noqa: PERF401
110+
elif isinstance(value, ast.AST):
111+
stack.append(value)
112+
113+
# This change improves from O(N) stack-frames per child-node to a single stack, less python call overhead
114+
return_statement = [test_node]
76115
call_node = None
77-
for node in ast.walk(test_node):
78-
if isinstance(node, ast.Call) and node_in_call_position(node, self.call_positions):
79-
call_node = node
80-
if isinstance(node.func, ast.Name):
81-
function_name = node.func.id
82116

83-
if self.function_object.is_async:
117+
# Minor optimization: Convert mode, function_name, test_class_name, qualified_name, etc to locals
118+
fn_obj = self.function_object
119+
module_path = self.module_path
120+
mode = self.mode
121+
qualified_name = fn_obj.qualified_name
122+
123+
# Use locals for all 'current' values, only look up class/function/constant AST object once.
124+
codeflash_loop_index = ast.Name(id="codeflash_loop_index", ctx=ast.Load())
125+
codeflash_cur = ast.Name(id="codeflash_cur", ctx=ast.Load())
126+
codeflash_con = ast.Name(id="codeflash_con", ctx=ast.Load())
127+
128+
for node in iter_ast_calls(test_node):
129+
if not node_in_call_position(node, self.call_positions):
130+
continue
131+
132+
call_node = node
133+
all_args = get_call_arguments(call_node)
134+
# Two possible call types: Name and Attribute
135+
node_func = node.func
136+
137+
if isinstance(node_func, ast.Name):
138+
function_name = node_func.id
139+
140+
# Check if this is the function we want to instrument
141+
if function_name != fn_obj.function_name:
142+
continue
143+
144+
if fn_obj.is_async:
145+
return [test_node]
146+
147+
# Build once, reuse objects.
148+
inspect_name = ast.Name(id="inspect", ctx=ast.Load())
149+
bind_call = ast.Assign(
150+
targets=[ast.Name(id="_call__bound__arguments", ctx=ast.Store())],
151+
value=ast.Call(
152+
func=ast.Attribute(
153+
value=ast.Call(
154+
func=ast.Attribute(value=inspect_name, attr="signature", ctx=ast.Load()),
155+
args=[ast.Name(id=function_name, ctx=ast.Load())],
156+
keywords=[],
157+
),
158+
attr="bind",
159+
ctx=ast.Load(),
160+
),
161+
args=all_args.args,
162+
keywords=all_args.keywords,
163+
),
164+
lineno=test_node.lineno,
165+
col_offset=test_node.col_offset,
166+
)
167+
168+
apply_defaults = ast.Expr(
169+
value=ast.Call(
170+
func=ast.Attribute(
171+
value=ast.Name(id="_call__bound__arguments", ctx=ast.Load()),
172+
attr="apply_defaults",
173+
ctx=ast.Load(),
174+
),
175+
args=[],
176+
keywords=[],
177+
),
178+
lineno=test_node.lineno + 1,
179+
col_offset=test_node.col_offset,
180+
)
181+
182+
node.func = ast.Name(id="codeflash_wrap", ctx=ast.Load())
183+
base_args = [
184+
ast.Name(id=function_name, ctx=ast.Load()),
185+
ast.Constant(value=module_path),
186+
ast.Constant(value=test_class_name or None),
187+
ast.Constant(value=node_name),
188+
ast.Constant(value=qualified_name),
189+
ast.Constant(value=index),
190+
codeflash_loop_index,
191+
]
192+
# Extend with BEHAVIOR extras if needed
193+
if mode == TestingMode.BEHAVIOR:
194+
base_args += [codeflash_cur, codeflash_con]
195+
# Extend with call args (performance) or starred bound args (behavior)
196+
if mode == TestingMode.PERFORMANCE:
197+
base_args += call_node.args
198+
else:
199+
base_args.append(
200+
ast.Starred(
201+
value=ast.Attribute(
202+
value=ast.Name(id="_call__bound__arguments", ctx=ast.Load()),
203+
attr="args",
204+
ctx=ast.Load(),
205+
),
206+
ctx=ast.Load(),
207+
)
208+
)
209+
node.args = base_args
210+
# Prepare keywords
211+
if mode == TestingMode.BEHAVIOR:
212+
node.keywords = [
213+
ast.keyword(
214+
value=ast.Attribute(
215+
value=ast.Name(id="_call__bound__arguments", ctx=ast.Load()),
216+
attr="kwargs",
217+
ctx=ast.Load(),
218+
)
219+
)
220+
]
221+
else:
222+
node.keywords = call_node.keywords
223+
224+
return_statement = (
225+
[bind_call, apply_defaults, test_node] if mode == TestingMode.BEHAVIOR else [test_node]
226+
)
227+
break
228+
if isinstance(node_func, ast.Attribute):
229+
function_to_test = node_func.attr
230+
if function_to_test == fn_obj.function_name:
231+
if fn_obj.is_async:
84232
return [test_node]
85233

234+
# Create the signature binding statements
235+
236+
# Unparse only once
237+
function_name_expr = ast.parse(ast.unparse(node_func), mode="eval").body
238+
239+
inspect_name = ast.Name(id="inspect", ctx=ast.Load())
240+
bind_call = ast.Assign(
241+
targets=[ast.Name(id="_call__bound__arguments", ctx=ast.Store())],
242+
value=ast.Call(
243+
func=ast.Attribute(
244+
value=ast.Call(
245+
func=ast.Attribute(value=inspect_name, attr="signature", ctx=ast.Load()),
246+
args=[function_name_expr],
247+
keywords=[],
248+
),
249+
attr="bind",
250+
ctx=ast.Load(),
251+
),
252+
args=all_args.args,
253+
keywords=all_args.keywords,
254+
),
255+
lineno=test_node.lineno,
256+
col_offset=test_node.col_offset,
257+
)
258+
259+
apply_defaults = ast.Expr(
260+
value=ast.Call(
261+
func=ast.Attribute(
262+
value=ast.Name(id="_call__bound__arguments", ctx=ast.Load()),
263+
attr="apply_defaults",
264+
ctx=ast.Load(),
265+
),
266+
args=[],
267+
keywords=[],
268+
),
269+
lineno=test_node.lineno + 1,
270+
col_offset=test_node.col_offset,
271+
)
272+
86273
node.func = ast.Name(id="codeflash_wrap", ctx=ast.Load())
87-
node.args = [
88-
ast.Name(id=function_name, ctx=ast.Load()),
89-
ast.Constant(value=self.module_path),
274+
base_args = [
275+
function_name_expr,
276+
ast.Constant(value=module_path),
90277
ast.Constant(value=test_class_name or None),
91278
ast.Constant(value=node_name),
92-
ast.Constant(value=self.function_object.qualified_name),
279+
ast.Constant(value=qualified_name),
93280
ast.Constant(value=index),
94-
ast.Name(id="codeflash_loop_index", ctx=ast.Load()),
95-
*(
96-
[ast.Name(id="codeflash_cur", ctx=ast.Load()), ast.Name(id="codeflash_con", ctx=ast.Load())]
97-
if self.mode == TestingMode.BEHAVIOR
98-
else []
99-
),
100-
*call_node.args,
281+
codeflash_loop_index,
101282
]
102-
node.keywords = call_node.keywords
103-
break
104-
if isinstance(node.func, ast.Attribute):
105-
function_to_test = node.func.attr
106-
if function_to_test == self.function_object.function_name:
107-
if self.function_object.is_async:
108-
return [test_node]
109-
110-
function_name = ast.unparse(node.func)
111-
node.func = ast.Name(id="codeflash_wrap", ctx=ast.Load())
112-
node.args = [
113-
ast.Name(id=function_name, ctx=ast.Load()),
114-
ast.Constant(value=self.module_path),
115-
ast.Constant(value=test_class_name or None),
116-
ast.Constant(value=node_name),
117-
ast.Constant(value=self.function_object.qualified_name),
118-
ast.Constant(value=index),
119-
ast.Name(id="codeflash_loop_index", ctx=ast.Load()),
120-
*(
121-
[
122-
ast.Name(id="codeflash_cur", ctx=ast.Load()),
123-
ast.Name(id="codeflash_con", ctx=ast.Load()),
124-
]
125-
if self.mode == TestingMode.BEHAVIOR
126-
else []
127-
),
128-
*call_node.args,
283+
if mode == TestingMode.BEHAVIOR:
284+
base_args += [codeflash_cur, codeflash_con]
285+
if mode == TestingMode.PERFORMANCE:
286+
base_args += call_node.args
287+
else:
288+
base_args.append(
289+
ast.Starred(
290+
value=ast.Attribute(
291+
value=ast.Name(id="_call__bound__arguments", ctx=ast.Load()),
292+
attr="args",
293+
ctx=ast.Load(),
294+
),
295+
ctx=ast.Load(),
296+
)
297+
)
298+
node.args = base_args
299+
if mode == TestingMode.BEHAVIOR:
300+
node.keywords = [
301+
ast.keyword(
302+
value=ast.Attribute(
303+
value=ast.Name(id="_call__bound__arguments", ctx=ast.Load()),
304+
attr="kwargs",
305+
ctx=ast.Load(),
306+
)
307+
)
129308
]
309+
else:
130310
node.keywords = call_node.keywords
131-
break
311+
312+
# Return the signature binding statements along with the test_node
313+
return_statement = (
314+
[bind_call, apply_defaults, test_node] if mode == TestingMode.BEHAVIOR else [test_node]
315+
)
316+
break
132317

133318
if call_node is None:
134319
return None
135-
return [test_node]
320+
return return_statement
136321

137322
def visit_ClassDef(self, node: ast.ClassDef) -> ast.ClassDef:
138323
# TODO: Ensure that this class inherits from unittest.TestCase. Don't modify non unittest.TestCase classes.
@@ -593,7 +778,11 @@ def inject_profiling_into_existing_test(
593778
]
594779
if mode == TestingMode.BEHAVIOR:
595780
new_imports.extend(
596-
[ast.Import(names=[ast.alias(name="sqlite3")]), ast.Import(names=[ast.alias(name="dill", asname="pickle")])]
781+
[
782+
ast.Import(names=[ast.alias(name="inspect")]),
783+
ast.Import(names=[ast.alias(name="sqlite3")]),
784+
ast.Import(names=[ast.alias(name="dill", asname="pickle")]),
785+
]
597786
)
598787
if test_framework == "unittest" and platform.system() != "Windows":
599788
new_imports.append(ast.Import(names=[ast.alias(name="timeout_decorator")]))

0 commit comments

Comments
 (0)