Skip to content

Commit b6d57ee

Browse files
authored
Merge pull request #737 from codeflash-ai/fix-overlapping-args-codeflash-wrap1
fix overlappings args in codeflash wrap
2 parents 9ac5d34 + 161f8cf commit b6d57ee

File tree

5 files changed

+106
-97
lines changed

5 files changed

+106
-97
lines changed

codeflash/code_utils/deduplicate_code.py

Lines changed: 45 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -1,39 +1,39 @@
11
import ast
22
import hashlib
3-
from typing import Dict, Set
43

54

65
class VariableNormalizer(ast.NodeTransformer):
76
"""Normalizes only local variable names in AST to canonical forms like var_0, var_1, etc.
7+
88
Preserves function names, class names, parameters, built-ins, and imported names.
99
"""
1010

11-
def __init__(self):
11+
def __init__(self) -> None:
1212
self.var_counter = 0
13-
self.var_mapping: Dict[str, str] = {}
13+
self.var_mapping: dict[str, str] = {}
1414
self.scope_stack = []
1515
self.builtins = set(dir(__builtins__))
16-
self.imports: Set[str] = set()
17-
self.global_vars: Set[str] = set()
18-
self.nonlocal_vars: Set[str] = set()
19-
self.parameters: Set[str] = set() # Track function parameters
16+
self.imports: set[str] = set()
17+
self.global_vars: set[str] = set()
18+
self.nonlocal_vars: set[str] = set()
19+
self.parameters: set[str] = set() # Track function parameters
2020

21-
def enter_scope(self):
22-
"""Enter a new scope (function/class)"""
21+
def enter_scope(self): # noqa : ANN201
22+
"""Enter a new scope (function/class)."""
2323
self.scope_stack.append(
2424
{"var_mapping": dict(self.var_mapping), "var_counter": self.var_counter, "parameters": set(self.parameters)}
2525
)
2626

27-
def exit_scope(self):
28-
"""Exit current scope and restore parent scope"""
27+
def exit_scope(self): # noqa : ANN201
28+
"""Exit current scope and restore parent scope."""
2929
if self.scope_stack:
3030
scope = self.scope_stack.pop()
3131
self.var_mapping = scope["var_mapping"]
3232
self.var_counter = scope["var_counter"]
3333
self.parameters = scope["parameters"]
3434

3535
def get_normalized_name(self, name: str) -> str:
36-
"""Get or create normalized name for a variable"""
36+
"""Get or create normalized name for a variable."""
3737
# Don't normalize if it's a builtin, import, global, nonlocal, or parameter
3838
if (
3939
name in self.builtins
@@ -50,34 +50,34 @@ def get_normalized_name(self, name: str) -> str:
5050
self.var_counter += 1
5151
return self.var_mapping[name]
5252

53-
def visit_Import(self, node):
54-
"""Track imported names"""
53+
def visit_Import(self, node): # noqa : ANN001, ANN201
54+
"""Track imported names."""
5555
for alias in node.names:
5656
name = alias.asname if alias.asname else alias.name
5757
self.imports.add(name.split(".")[0])
5858
return node
5959

60-
def visit_ImportFrom(self, node):
61-
"""Track imported names from modules"""
60+
def visit_ImportFrom(self, node): # noqa : ANN001, ANN201
61+
"""Track imported names from modules."""
6262
for alias in node.names:
6363
name = alias.asname if alias.asname else alias.name
6464
self.imports.add(name)
6565
return node
6666

67-
def visit_Global(self, node):
68-
"""Track global variable declarations"""
67+
def visit_Global(self, node): # noqa : ANN001, ANN201
68+
"""Track global variable declarations."""
6969
# Avoid repeated .add calls by using set.update with list
7070
self.global_vars.update(node.names)
7171
return node
7272

73-
def visit_Nonlocal(self, node):
74-
"""Track nonlocal variable declarations"""
75-
for name in node.names:
76-
self.nonlocal_vars.add(name)
73+
def visit_Nonlocal(self, node): # noqa : ANN001, ANN201
74+
"""Track nonlocal variable declarations."""
75+
# Using set.update for batch insertion (faster than add-in-loop)
76+
self.nonlocal_vars.update(node.names)
7777
return node
7878

79-
def visit_FunctionDef(self, node):
80-
"""Process function but keep function name and parameters unchanged"""
79+
def visit_FunctionDef(self, node): # noqa : ANN001, ANN201
80+
"""Process function but keep function name and parameters unchanged."""
8181
self.enter_scope()
8282

8383
# Track all parameters (don't modify them)
@@ -95,19 +95,19 @@ def visit_FunctionDef(self, node):
9595
self.exit_scope()
9696
return node
9797

98-
def visit_AsyncFunctionDef(self, node):
99-
"""Handle async functions same as regular functions"""
98+
def visit_AsyncFunctionDef(self, node): # noqa : ANN001, ANN201
99+
"""Handle async functions same as regular functions."""
100100
return self.visit_FunctionDef(node)
101101

102-
def visit_ClassDef(self, node):
103-
"""Process class but keep class name unchanged"""
102+
def visit_ClassDef(self, node): # noqa : ANN001, ANN201
103+
"""Process class but keep class name unchanged."""
104104
self.enter_scope()
105105
node = self.generic_visit(node)
106106
self.exit_scope()
107107
return node
108108

109-
def visit_Name(self, node):
110-
"""Normalize variable names in Name nodes"""
109+
def visit_Name(self, node): # noqa : ANN001, ANN201
110+
"""Normalize variable names in Name nodes."""
111111
if isinstance(node.ctx, (ast.Store, ast.Del)):
112112
# For assignments and deletions, check if we should normalize
113113
if (
@@ -118,20 +118,20 @@ def visit_Name(self, node):
118118
and node.id not in self.nonlocal_vars
119119
):
120120
node.id = self.get_normalized_name(node.id)
121-
elif isinstance(node.ctx, ast.Load):
121+
elif isinstance(node.ctx, ast.Load): # noqa : SIM102
122122
# For loading, use existing mapping if available
123123
if node.id in self.var_mapping:
124124
node.id = self.var_mapping[node.id]
125125
return node
126126

127-
def visit_ExceptHandler(self, node):
128-
"""Normalize exception variable names"""
127+
def visit_ExceptHandler(self, node): # noqa : ANN001, ANN201
128+
"""Normalize exception variable names."""
129129
if node.name:
130130
node.name = self.get_normalized_name(node.name)
131131
return self.generic_visit(node)
132132

133-
def visit_comprehension(self, node):
134-
"""Normalize comprehension target variables"""
133+
def visit_comprehension(self, node): # noqa : ANN001, ANN201
134+
"""Normalize comprehension target variables."""
135135
# Create new scope for comprehension
136136
old_mapping = dict(self.var_mapping)
137137
old_counter = self.var_counter
@@ -144,23 +144,25 @@ def visit_comprehension(self, node):
144144
self.var_counter = old_counter
145145
return node
146146

147-
def visit_For(self, node):
148-
"""Handle for loop target variables"""
147+
def visit_For(self, node): # noqa : ANN001, ANN201
148+
"""Handle for loop target variables."""
149149
# The target in a for loop is a local variable that should be normalized
150150
return self.generic_visit(node)
151151

152-
def visit_With(self, node):
153-
"""Handle with statement as variables"""
152+
def visit_With(self, node): # noqa : ANN001, ANN201
153+
"""Handle with statement as variables."""
154154
return self.generic_visit(node)
155155

156156

157-
def normalize_code(code: str, remove_docstrings: bool = True, return_ast_dump: bool = False) -> str:
157+
def normalize_code(code: str, remove_docstrings: bool = True, return_ast_dump: bool = False) -> str: # noqa : FBT002, FBT001
158158
"""Normalize Python code by parsing, cleaning, and normalizing only variable names.
159+
159160
Function names, class names, and parameters are preserved.
160161
161162
Args:
162163
code: Python source code as string
163164
remove_docstrings: Whether to remove docstrings
165+
return_ast_dump: return_ast_dump
164166
165167
Returns:
166168
Normalized code as string
@@ -191,7 +193,7 @@ def normalize_code(code: str, remove_docstrings: bool = True, return_ast_dump: b
191193
raise ValueError(msg) from e
192194

193195

194-
def remove_docstrings_from_ast(node):
196+
def remove_docstrings_from_ast(node): # noqa : ANN001, ANN201
195197
"""Remove docstrings from AST nodes."""
196198
# Only FunctionDef, AsyncFunctionDef, ClassDef, and Module can contain docstrings in their body[0]
197199
node_types = (ast.FunctionDef, ast.AsyncFunctionDef, ast.ClassDef, ast.Module)
@@ -242,6 +244,7 @@ def are_codes_duplicate(code1: str, code2: str) -> bool:
242244
try:
243245
normalized1 = normalize_code(code1, return_ast_dump=True)
244246
normalized2 = normalize_code(code2, return_ast_dump=True)
245-
return normalized1 == normalized2
246247
except Exception:
247248
return False
249+
else:
250+
return normalized1 == normalized2

codeflash/code_utils/instrument_existing_tests.py

Lines changed: 31 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -365,15 +365,15 @@ def create_wrapper_function(mode: TestingMode = TestingMode.BEHAVIOR) -> ast.Fun
365365
targets=[ast.Name(id="test_id", ctx=ast.Store())],
366366
value=ast.JoinedStr(
367367
values=[
368-
ast.FormattedValue(value=ast.Name(id="test_module_name", ctx=ast.Load()), conversion=-1),
368+
ast.FormattedValue(value=ast.Name(id="codeflash_test_module_name", ctx=ast.Load()), conversion=-1),
369369
ast.Constant(value=":"),
370-
ast.FormattedValue(value=ast.Name(id="test_class_name", ctx=ast.Load()), conversion=-1),
370+
ast.FormattedValue(value=ast.Name(id="codeflash_test_class_name", ctx=ast.Load()), conversion=-1),
371371
ast.Constant(value=":"),
372-
ast.FormattedValue(value=ast.Name(id="test_name", ctx=ast.Load()), conversion=-1),
372+
ast.FormattedValue(value=ast.Name(id="codeflash_test_name", ctx=ast.Load()), conversion=-1),
373373
ast.Constant(value=":"),
374-
ast.FormattedValue(value=ast.Name(id="line_id", ctx=ast.Load()), conversion=-1),
374+
ast.FormattedValue(value=ast.Name(id="codeflash_line_id", ctx=ast.Load()), conversion=-1),
375375
ast.Constant(value=":"),
376-
ast.FormattedValue(value=ast.Name(id="loop_index", ctx=ast.Load()), conversion=-1),
376+
ast.FormattedValue(value=ast.Name(id="codeflash_loop_index", ctx=ast.Load()), conversion=-1),
377377
]
378378
),
379379
lineno=lineno + 1,
@@ -453,7 +453,7 @@ def create_wrapper_function(mode: TestingMode = TestingMode.BEHAVIOR) -> ast.Fun
453453
targets=[ast.Name(id="invocation_id", ctx=ast.Store())],
454454
value=ast.JoinedStr(
455455
values=[
456-
ast.FormattedValue(value=ast.Name(id="line_id", ctx=ast.Load()), conversion=-1),
456+
ast.FormattedValue(value=ast.Name(id="codeflash_line_id", ctx=ast.Load()), conversion=-1),
457457
ast.Constant(value="_"),
458458
ast.FormattedValue(value=ast.Name(id="codeflash_test_index", ctx=ast.Load()), conversion=-1),
459459
]
@@ -466,25 +466,31 @@ def create_wrapper_function(mode: TestingMode = TestingMode.BEHAVIOR) -> ast.Fun
466466
targets=[ast.Name(id="test_stdout_tag", ctx=ast.Store())],
467467
value=ast.JoinedStr(
468468
values=[
469-
ast.FormattedValue(value=ast.Name(id="test_module_name", ctx=ast.Load()), conversion=-1),
469+
ast.FormattedValue(
470+
value=ast.Name(id="codeflash_test_module_name", ctx=ast.Load()), conversion=-1
471+
),
470472
ast.Constant(value=":"),
471473
ast.FormattedValue(
472474
value=ast.IfExp(
473-
test=ast.Name(id="test_class_name", ctx=ast.Load()),
475+
test=ast.Name(id="codeflash_test_class_name", ctx=ast.Load()),
474476
body=ast.BinOp(
475-
left=ast.Name(id="test_class_name", ctx=ast.Load()),
477+
left=ast.Name(id="codeflash_test_class_name", ctx=ast.Load()),
476478
op=ast.Add(),
477479
right=ast.Constant(value="."),
478480
),
479481
orelse=ast.Constant(value=""),
480482
),
481483
conversion=-1,
482484
),
483-
ast.FormattedValue(value=ast.Name(id="test_name", ctx=ast.Load()), conversion=-1),
485+
ast.FormattedValue(value=ast.Name(id="codeflash_test_name", ctx=ast.Load()), conversion=-1),
484486
ast.Constant(value=":"),
485-
ast.FormattedValue(value=ast.Name(id="function_name", ctx=ast.Load()), conversion=-1),
487+
ast.FormattedValue(
488+
value=ast.Name(id="codeflash_function_name", ctx=ast.Load()), conversion=-1
489+
),
486490
ast.Constant(value=":"),
487-
ast.FormattedValue(value=ast.Name(id="loop_index", ctx=ast.Load()), conversion=-1),
491+
ast.FormattedValue(
492+
value=ast.Name(id="codeflash_loop_index", ctx=ast.Load()), conversion=-1
493+
),
488494
ast.Constant(value=":"),
489495
ast.FormattedValue(value=ast.Name(id="invocation_id", ctx=ast.Load()), conversion=-1),
490496
]
@@ -537,7 +543,7 @@ def create_wrapper_function(mode: TestingMode = TestingMode.BEHAVIOR) -> ast.Fun
537543
ast.Assign(
538544
targets=[ast.Name(id="return_value", ctx=ast.Store())],
539545
value=ast.Call(
540-
func=ast.Name(id="wrapped", ctx=ast.Load()),
546+
func=ast.Name(id="codeflash_wrapped", ctx=ast.Load()),
541547
args=[ast.Starred(value=ast.Name(id="args", ctx=ast.Load()), ctx=ast.Load())],
542548
keywords=[ast.keyword(arg=None, value=ast.Name(id="kwargs", ctx=ast.Load()))],
543549
),
@@ -664,11 +670,11 @@ def create_wrapper_function(mode: TestingMode = TestingMode.BEHAVIOR) -> ast.Fun
664670
ast.Constant(value="INSERT INTO test_results VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)"),
665671
ast.Tuple(
666672
elts=[
667-
ast.Name(id="test_module_name", ctx=ast.Load()),
668-
ast.Name(id="test_class_name", ctx=ast.Load()),
669-
ast.Name(id="test_name", ctx=ast.Load()),
670-
ast.Name(id="function_name", ctx=ast.Load()),
671-
ast.Name(id="loop_index", ctx=ast.Load()),
673+
ast.Name(id="codeflash_test_module_name", ctx=ast.Load()),
674+
ast.Name(id="codeflash_test_class_name", ctx=ast.Load()),
675+
ast.Name(id="codeflash_test_name", ctx=ast.Load()),
676+
ast.Name(id="codeflash_function_name", ctx=ast.Load()),
677+
ast.Name(id="codeflash_loop_index", ctx=ast.Load()),
672678
ast.Name(id="invocation_id", ctx=ast.Load()),
673679
ast.Name(id="codeflash_duration", ctx=ast.Load()),
674680
ast.Name(id="pickled_return_value", ctx=ast.Load()),
@@ -707,13 +713,13 @@ def create_wrapper_function(mode: TestingMode = TestingMode.BEHAVIOR) -> ast.Fun
707713
name="codeflash_wrap",
708714
args=ast.arguments(
709715
args=[
710-
ast.arg(arg="wrapped", annotation=None),
711-
ast.arg(arg="test_module_name", annotation=None),
712-
ast.arg(arg="test_class_name", annotation=None),
713-
ast.arg(arg="test_name", annotation=None),
714-
ast.arg(arg="function_name", annotation=None),
715-
ast.arg(arg="line_id", annotation=None),
716-
ast.arg(arg="loop_index", annotation=None),
716+
ast.arg(arg="codeflash_wrapped", annotation=None),
717+
ast.arg(arg="codeflash_test_module_name", annotation=None),
718+
ast.arg(arg="codeflash_test_class_name", annotation=None),
719+
ast.arg(arg="codeflash_test_name", annotation=None),
720+
ast.arg(arg="codeflash_function_name", annotation=None),
721+
ast.arg(arg="codeflash_line_id", annotation=None),
722+
ast.arg(arg="codeflash_loop_index", annotation=None),
717723
*([ast.arg(arg="codeflash_cur", annotation=None)] if mode == TestingMode.BEHAVIOR else []),
718724
*([ast.arg(arg="codeflash_con", annotation=None)] if mode == TestingMode.BEHAVIOR else []),
719725
],

codeflash/models/models.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -558,7 +558,7 @@ def unique_invocation_loop_id(self) -> str:
558558
return f"{self.loop_index}:{self.id.id()}"
559559

560560

561-
class TestResults(BaseModel):
561+
class TestResults(BaseModel): # noqa: PLW1641
562562
# don't modify these directly, use the add method
563563
# also we don't support deletion of test results elements - caution is advised
564564
test_results: list[FunctionTestInvocation] = []

tests/test_instrument_all_and_run.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -15,31 +15,31 @@
1515
from codeflash.verification.instrument_codeflash_capture import instrument_codeflash_capture
1616

1717
# Used by cli instrumentation
18-
codeflash_wrap_string = """def codeflash_wrap(wrapped, test_module_name, test_class_name, test_name, function_name, line_id, loop_index, codeflash_cur, codeflash_con, *args, **kwargs):
19-
test_id = f'{{test_module_name}}:{{test_class_name}}:{{test_name}}:{{line_id}}:{{loop_index}}'
18+
codeflash_wrap_string = """def codeflash_wrap(codeflash_wrapped, codeflash_test_module_name, codeflash_test_class_name, codeflash_test_name, codeflash_function_name, codeflash_line_id, codeflash_loop_index, codeflash_cur, codeflash_con, *args, **kwargs):
19+
test_id = f'{{codeflash_test_module_name}}:{{codeflash_test_class_name}}:{{codeflash_test_name}}:{{codeflash_line_id}}:{{codeflash_loop_index}}'
2020
if not hasattr(codeflash_wrap, 'index'):
2121
codeflash_wrap.index = {{}}
2222
if test_id in codeflash_wrap.index:
2323
codeflash_wrap.index[test_id] += 1
2424
else:
2525
codeflash_wrap.index[test_id] = 0
2626
codeflash_test_index = codeflash_wrap.index[test_id]
27-
invocation_id = f'{{line_id}}_{{codeflash_test_index}}'
28-
test_stdout_tag = f"{{test_module_name}}:{{(test_class_name + '.' if test_class_name else '')}}{{test_name}}:{{function_name}}:{{loop_index}}:{{invocation_id}}"
27+
invocation_id = f'{{codeflash_line_id}}_{{codeflash_test_index}}'
28+
test_stdout_tag = f"{{codeflash_test_module_name}}:{{(codeflash_test_class_name + '.' if codeflash_test_class_name else '')}}{{codeflash_test_name}}:{{codeflash_function_name}}:{{codeflash_loop_index}}:{{invocation_id}}"
2929
print(f"!$######{{test_stdout_tag}}######$!")
3030
exception = None
3131
gc.disable()
3232
try:
3333
counter = time.perf_counter_ns()
34-
return_value = wrapped(*args, **kwargs)
34+
return_value = codeflash_wrapped(*args, **kwargs)
3535
codeflash_duration = time.perf_counter_ns() - counter
3636
except Exception as e:
3737
codeflash_duration = time.perf_counter_ns() - counter
3838
exception = e
3939
gc.enable()
4040
print(f"!######{{test_stdout_tag}}######!")
4141
pickled_return_value = pickle.dumps(exception) if exception else pickle.dumps(return_value)
42-
codeflash_cur.execute('INSERT INTO test_results VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)', (test_module_name, test_class_name, test_name, function_name, loop_index, invocation_id, codeflash_duration, pickled_return_value, 'function_call'))
42+
codeflash_cur.execute('INSERT INTO test_results VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)', (codeflash_test_module_name, codeflash_test_class_name, codeflash_test_name, codeflash_function_name, codeflash_loop_index, invocation_id, codeflash_duration, pickled_return_value, 'function_call'))
4343
codeflash_con.commit()
4444
if exception:
4545
raise exception

0 commit comments

Comments
 (0)