Skip to content

Commit 95a38d3

Browse files
committed
deduplicate optimizations better
Signed-off-by: Saurabh Misra <misra.saurabh1@gmail.com>
1 parent 2802ae6 commit 95a38d3

File tree

3 files changed

+372
-1
lines changed

3 files changed

+372
-1
lines changed
Lines changed: 235 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,235 @@
1+
import ast
2+
import hashlib
3+
from typing import Dict, Set
4+
5+
6+
class VariableNormalizer(ast.NodeTransformer):
7+
"""Normalizes only local variable names in AST to canonical forms like var_0, var_1, etc.
8+
Preserves function names, class names, parameters, built-ins, and imported names.
9+
"""
10+
11+
def __init__(self):
12+
self.var_counter = 0
13+
self.var_mapping: Dict[str, str] = {}
14+
self.scope_stack = []
15+
self.builtins = set(dir(__builtins__))
16+
self.imports: Set[str] = set()
17+
self.global_vars: Set[str] = set()
18+
self.nonlocal_vars: Set[str] = set()
19+
self.parameters: Set[str] = set() # Track function parameters
20+
21+
def enter_scope(self):
22+
"""Enter a new scope (function/class)"""
23+
self.scope_stack.append(
24+
{"var_mapping": dict(self.var_mapping), "var_counter": self.var_counter, "parameters": set(self.parameters)}
25+
)
26+
27+
def exit_scope(self):
28+
"""Exit current scope and restore parent scope"""
29+
if self.scope_stack:
30+
scope = self.scope_stack.pop()
31+
self.var_mapping = scope["var_mapping"]
32+
self.var_counter = scope["var_counter"]
33+
self.parameters = scope["parameters"]
34+
35+
def get_normalized_name(self, name: str) -> str:
36+
"""Get or create normalized name for a variable"""
37+
# Don't normalize if it's a builtin, import, global, nonlocal, or parameter
38+
if (
39+
name in self.builtins
40+
or name in self.imports
41+
or name in self.global_vars
42+
or name in self.nonlocal_vars
43+
or name in self.parameters
44+
):
45+
return name
46+
47+
# Only normalize local variables
48+
if name not in self.var_mapping:
49+
self.var_mapping[name] = f"var_{self.var_counter}"
50+
self.var_counter += 1
51+
return self.var_mapping[name]
52+
53+
def visit_Import(self, node):
54+
"""Track imported names"""
55+
for alias in node.names:
56+
name = alias.asname if alias.asname else alias.name
57+
self.imports.add(name.split(".")[0])
58+
return node
59+
60+
def visit_ImportFrom(self, node):
61+
"""Track imported names from modules"""
62+
for alias in node.names:
63+
name = alias.asname if alias.asname else alias.name
64+
self.imports.add(name)
65+
return node
66+
67+
def visit_Global(self, node):
68+
"""Track global variable declarations"""
69+
for name in node.names:
70+
self.global_vars.add(name)
71+
return node
72+
73+
def visit_Nonlocal(self, node):
74+
"""Track nonlocal variable declarations"""
75+
for name in node.names:
76+
self.nonlocal_vars.add(name)
77+
return node
78+
79+
def visit_FunctionDef(self, node):
80+
"""Process function but keep function name and parameters unchanged"""
81+
self.enter_scope()
82+
83+
# Track all parameters (don't modify them)
84+
for arg in node.args.args:
85+
self.parameters.add(arg.arg)
86+
if node.args.vararg:
87+
self.parameters.add(node.args.vararg.arg)
88+
if node.args.kwarg:
89+
self.parameters.add(node.args.kwarg.arg)
90+
for arg in node.args.kwonlyargs:
91+
self.parameters.add(arg.arg)
92+
93+
# Visit function body
94+
node = self.generic_visit(node)
95+
self.exit_scope()
96+
return node
97+
98+
def visit_AsyncFunctionDef(self, node):
99+
"""Handle async functions same as regular functions"""
100+
return self.visit_FunctionDef(node)
101+
102+
def visit_ClassDef(self, node):
103+
"""Process class but keep class name unchanged"""
104+
self.enter_scope()
105+
node = self.generic_visit(node)
106+
self.exit_scope()
107+
return node
108+
109+
def visit_Name(self, node):
110+
"""Normalize variable names in Name nodes"""
111+
if isinstance(node.ctx, (ast.Store, ast.Del)):
112+
# For assignments and deletions, check if we should normalize
113+
if (
114+
node.id not in self.builtins
115+
and node.id not in self.imports
116+
and node.id not in self.parameters
117+
and node.id not in self.global_vars
118+
and node.id not in self.nonlocal_vars
119+
):
120+
node.id = self.get_normalized_name(node.id)
121+
elif isinstance(node.ctx, ast.Load):
122+
# For loading, use existing mapping if available
123+
if node.id in self.var_mapping:
124+
node.id = self.var_mapping[node.id]
125+
return node
126+
127+
def visit_ExceptHandler(self, node):
128+
"""Normalize exception variable names"""
129+
if node.name:
130+
node.name = self.get_normalized_name(node.name)
131+
return self.generic_visit(node)
132+
133+
def visit_comprehension(self, node):
134+
"""Normalize comprehension target variables"""
135+
# Create new scope for comprehension
136+
old_mapping = dict(self.var_mapping)
137+
old_counter = self.var_counter
138+
139+
# Process the comprehension
140+
node = self.generic_visit(node)
141+
142+
# Restore scope
143+
self.var_mapping = old_mapping
144+
self.var_counter = old_counter
145+
return node
146+
147+
def visit_For(self, node):
148+
"""Handle for loop target variables"""
149+
# The target in a for loop is a local variable that should be normalized
150+
return self.generic_visit(node)
151+
152+
def visit_With(self, node):
153+
"""Handle with statement as variables"""
154+
return self.generic_visit(node)
155+
156+
157+
def normalize_code(code: str, remove_docstrings: bool = True) -> str:
158+
"""Normalize Python code by parsing, cleaning, and normalizing only variable names.
159+
Function names, class names, and parameters are preserved.
160+
161+
Args:
162+
code: Python source code as string
163+
remove_docstrings: Whether to remove docstrings
164+
165+
Returns:
166+
Normalized code as string
167+
168+
"""
169+
try:
170+
# Parse the code
171+
tree = ast.parse(code)
172+
173+
# Remove docstrings if requested
174+
if remove_docstrings:
175+
remove_docstrings_from_ast(tree)
176+
177+
# Normalize variable names
178+
normalizer = VariableNormalizer()
179+
normalized_tree = normalizer.visit(tree)
180+
181+
# Fix missing locations in the AST
182+
ast.fix_missing_locations(normalized_tree)
183+
184+
# Unparse back to code
185+
return ast.unparse(normalized_tree)
186+
except SyntaxError as e:
187+
msg = f"Invalid Python syntax: {e}"
188+
raise ValueError(msg) from e
189+
190+
191+
def remove_docstrings_from_ast(node):
192+
"""Remove docstrings from AST nodes."""
193+
# Process all nodes in the tree, but avoid recursion
194+
for current_node in ast.walk(node):
195+
if isinstance(current_node, (ast.FunctionDef, ast.AsyncFunctionDef, ast.ClassDef, ast.Module)):
196+
if (
197+
current_node.body
198+
and isinstance(current_node.body[0], ast.Expr)
199+
and isinstance(current_node.body[0].value, ast.Constant)
200+
and isinstance(current_node.body[0].value.value, str)
201+
):
202+
current_node.body = current_node.body[1:]
203+
204+
205+
def get_code_fingerprint(code: str) -> str:
206+
"""Generate a fingerprint for normalized code.
207+
208+
Args:
209+
code: Python source code
210+
211+
Returns:
212+
SHA-256 hash of normalized code
213+
214+
"""
215+
normalized = normalize_code(code)
216+
return hashlib.sha256(normalized.encode()).hexdigest()
217+
218+
219+
def are_codes_duplicate(code1: str, code2: str) -> bool:
220+
"""Check if two code segments are duplicates after normalization.
221+
222+
Args:
223+
code1: First code segment
224+
code2: Second code segment
225+
226+
Returns:
227+
True if codes are structurally identical (ignoring local variable names)
228+
229+
"""
230+
try:
231+
normalized1 = normalize_code(code1)
232+
normalized2 = normalize_code(code2)
233+
return normalized1 == normalized2
234+
except Exception:
235+
return False

codeflash/optimization/function_optimizer.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@
4848
REPEAT_OPTIMIZATION_PROBABILITY,
4949
TOTAL_LOOPING_TIME,
5050
)
51+
from codeflash.code_utils.deduplicate_code import normalize_code
5152
from codeflash.code_utils.edit_generated_tests import (
5253
add_runtime_comments_to_generated_tests,
5354
remove_functions_from_generated_tests,
@@ -519,7 +520,7 @@ def determine_best_candidate(
519520
)
520521
continue
521522
# check if this code has been evaluated before by checking the ast normalized code string
522-
normalized_code = ast.unparse(ast.parse(candidate.source_code.flat.strip()))
523+
normalized_code = normalize_code(candidate.source_code.flat.strip())
523524
if normalized_code in ast_code_to_id:
524525
logger.info(
525526
"Current candidate has been encountered before in testing, Skipping optimization candidate."

tests/test_code_deduplication.py

Lines changed: 135 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,135 @@
1+
from codeflash.code_utils.deduplicate_code import are_codes_duplicate, normalize_code
2+
3+
4+
def test_deduplicate1():
5+
# Example usage and tests
6+
# Example 1: Same logic, different variable names (should NOT match due to different function/param names)
7+
code1 = """
8+
def compute_sum(numbers):
9+
'''Calculate sum of numbers'''
10+
total = 0
11+
for num in numbers:
12+
total += num
13+
return total
14+
"""
15+
16+
code2 = """
17+
def compute_sum(numbers):
18+
# This computes the sum
19+
result = 0
20+
for value in numbers:
21+
result += value
22+
return result
23+
"""
24+
25+
assert normalize_code(code1) == normalize_code(code2)
26+
assert are_codes_duplicate(code1, code2)
27+
28+
# Example 3: Same function and parameter names, different local variables (should match)
29+
code3 = """
30+
def calculate_sum(numbers):
31+
accumulator = 0
32+
for item in numbers:
33+
accumulator += item
34+
return accumulator
35+
"""
36+
37+
code4 = """
38+
def calculate_sum(numbers):
39+
total = 0
40+
for num in numbers:
41+
total += num
42+
return total
43+
"""
44+
45+
assert normalize_code(code3) == normalize_code(code4)
46+
assert are_codes_duplicate(code3, code4)
47+
48+
# Example 4: Nested functions and classes (preserving names)
49+
code5 = """
50+
class DataProcessor:
51+
def __init__(self, data):
52+
self.data = data
53+
54+
def process(self):
55+
def helper(item):
56+
temp = item * 2
57+
return temp
58+
59+
results = []
60+
for element in self.data:
61+
results.append(helper(element))
62+
return results
63+
"""
64+
65+
code6 = """
66+
class DataProcessor:
67+
def __init__(self, data):
68+
self.data = data
69+
70+
def process(self):
71+
def helper(item):
72+
x = item * 2
73+
return x
74+
75+
output = []
76+
for thing in self.data:
77+
output.append(helper(thing))
78+
return output
79+
"""
80+
81+
assert normalize_code(code5) == normalize_code(code6)
82+
83+
# Example 5: With imports and built-ins (these should be preserved)
84+
code7 = """
85+
import math
86+
87+
def calculate_circle_area(radius):
88+
pi_value = math.pi
89+
area = pi_value * radius ** 2
90+
return area
91+
"""
92+
93+
code8 = """
94+
import math
95+
96+
def calculate_circle_area(radius):
97+
constant = math.pi
98+
result = constant * radius ** 2
99+
return result
100+
"""
101+
code85 = """
102+
import math
103+
104+
def calculate_circle_area(radius):
105+
constant = math.pi
106+
result = constant *2 * radius ** 2
107+
return result
108+
"""
109+
110+
assert normalize_code(code7) == normalize_code(code8)
111+
assert normalize_code(code8) != normalize_code(code85)
112+
113+
# Example 6: Exception handling
114+
code9 = """
115+
def safe_divide(a, b):
116+
try:
117+
result = a / b
118+
return result
119+
except ZeroDivisionError as e:
120+
error_msg = str(e)
121+
return None
122+
"""
123+
124+
code10 = """
125+
def safe_divide(a, b):
126+
try:
127+
output = a / b
128+
return output
129+
except ZeroDivisionError as exc:
130+
message = str(exc)
131+
return None
132+
"""
133+
assert normalize_code(code9) == normalize_code(code10)
134+
135+
assert normalize_code(code9) != normalize_code(code8)

0 commit comments

Comments
 (0)