Skip to content

Commit 630ca8a

Browse files
Merge pull request #898 from codeflash-ai/fix/handle-new-added-classes
[FIX] Handle new added classes
2 parents 827cf36 + 1871a87 commit 630ca8a

File tree

3 files changed

+197
-45
lines changed

3 files changed

+197
-45
lines changed

codeflash/code_utils/code_extractor.py

Lines changed: 28 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,33 @@ def visit_Assign(self, node: cst.Assign) -> Optional[bool]:
7272
return True
7373

7474

75+
def find_insertion_index_after_imports(node: cst.Module) -> int:
76+
"""Find the position of the last import statement in the top-level of the module."""
77+
insert_index = 0
78+
for i, stmt in enumerate(node.body):
79+
is_top_level_import = isinstance(stmt, cst.SimpleStatementLine) and any(
80+
isinstance(child, (cst.Import, cst.ImportFrom)) for child in stmt.body
81+
)
82+
83+
is_conditional_import = isinstance(stmt, cst.If) and all(
84+
isinstance(inner, cst.SimpleStatementLine)
85+
and all(isinstance(child, (cst.Import, cst.ImportFrom)) for child in inner.body)
86+
for inner in stmt.body.body
87+
)
88+
89+
if is_top_level_import or is_conditional_import:
90+
insert_index = i + 1
91+
92+
# Stop scanning once we reach a class or function definition.
93+
# Imports are supposed to be at the top of the file, but they can technically appear anywhere, even at the bottom of the file.
94+
# Without this check, a stray import later in the file
95+
# would incorrectly shift our insertion index below actual code definitions.
96+
if isinstance(stmt, (cst.ClassDef, cst.FunctionDef)):
97+
break
98+
99+
return insert_index
100+
101+
75102
class GlobalAssignmentTransformer(cst.CSTTransformer):
76103
"""Transforms global assignments in the original file with those from the new file."""
77104

@@ -122,32 +149,6 @@ def leave_Assign(self, original_node: cst.Assign, updated_node: cst.Assign) -> c
122149

123150
return updated_node
124151

125-
def _find_insertion_index(self, updated_node: cst.Module) -> int:
126-
"""Find the position of the last import statement in the top-level of the module."""
127-
insert_index = 0
128-
for i, stmt in enumerate(updated_node.body):
129-
is_top_level_import = isinstance(stmt, cst.SimpleStatementLine) and any(
130-
isinstance(child, (cst.Import, cst.ImportFrom)) for child in stmt.body
131-
)
132-
133-
is_conditional_import = isinstance(stmt, cst.If) and all(
134-
isinstance(inner, cst.SimpleStatementLine)
135-
and all(isinstance(child, (cst.Import, cst.ImportFrom)) for child in inner.body)
136-
for inner in stmt.body.body
137-
)
138-
139-
if is_top_level_import or is_conditional_import:
140-
insert_index = i + 1
141-
142-
# Stop scanning once we reach a class or function definition.
143-
# Imports are supposed to be at the top of the file, but they can technically appear anywhere, even at the bottom of the file.
144-
# Without this check, a stray import later in the file
145-
# would incorrectly shift our insertion index below actual code definitions.
146-
if isinstance(stmt, (cst.ClassDef, cst.FunctionDef)):
147-
break
148-
149-
return insert_index
150-
151152
def leave_Module(self, original_node: cst.Module, updated_node: cst.Module) -> cst.Module: # noqa: ARG002
152153
# Add any new assignments that weren't in the original file
153154
new_statements = list(updated_node.body)
@@ -161,7 +162,7 @@ def leave_Module(self, original_node: cst.Module, updated_node: cst.Module) -> c
161162

162163
if assignments_to_append:
163164
# after last top-level imports
164-
insert_index = self._find_insertion_index(updated_node)
165+
insert_index = find_insertion_index_after_imports(updated_node)
165166

166167
assignment_lines = [
167168
cst.SimpleStatementLine([assignment], leading_lines=[cst.EmptyLine()])

codeflash/code_utils/code_replacer.py

Lines changed: 36 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,18 @@
33
import ast
44
from collections import defaultdict
55
from functools import lru_cache
6+
from itertools import chain
67
from typing import TYPE_CHECKING, Optional, TypeVar
78

89
import libcst as cst
910
from libcst.metadata import PositionProvider
1011

1112
from codeflash.cli_cmds.console import logger
12-
from codeflash.code_utils.code_extractor import add_global_assignments, add_needed_imports_from_module
13+
from codeflash.code_utils.code_extractor import (
14+
add_global_assignments,
15+
add_needed_imports_from_module,
16+
find_insertion_index_after_imports,
17+
)
1318
from codeflash.code_utils.config_parser import find_conftest_files
1419
from codeflash.code_utils.formatter import sort_imports
1520
from codeflash.code_utils.line_profile_utils import ImportAdder
@@ -249,6 +254,7 @@ def __init__(
249254
] = {} # keys are (class_name, function_name)
250255
self.new_functions: list[cst.FunctionDef] = []
251256
self.new_class_functions: dict[str, list[cst.FunctionDef]] = defaultdict(list)
257+
self.new_classes: list[cst.ClassDef] = []
252258
self.current_class = None
253259
self.modified_init_functions: dict[str, cst.FunctionDef] = {}
254260

@@ -271,6 +277,10 @@ def visit_ClassDef(self, node: cst.ClassDef) -> bool:
271277
self.current_class = node.name.value
272278

273279
parents = (FunctionParent(name=node.name.value, type="ClassDef"),)
280+
281+
if (node.name.value, ()) not in self.preexisting_objects:
282+
self.new_classes.append(node)
283+
274284
for child_node in node.body.body:
275285
if (
276286
self.preexisting_objects
@@ -290,13 +300,15 @@ class OptimFunctionReplacer(cst.CSTTransformer):
290300
def __init__(
291301
self,
292302
modified_functions: Optional[dict[tuple[str | None, str], cst.FunctionDef]] = None,
303+
new_classes: Optional[list[cst.ClassDef]] = None,
293304
new_functions: Optional[list[cst.FunctionDef]] = None,
294305
new_class_functions: Optional[dict[str, list[cst.FunctionDef]]] = None,
295306
modified_init_functions: Optional[dict[str, cst.FunctionDef]] = None,
296307
) -> None:
297308
super().__init__()
298309
self.modified_functions = modified_functions if modified_functions is not None else {}
299310
self.new_functions = new_functions if new_functions is not None else []
311+
self.new_classes = new_classes if new_classes is not None else []
300312
self.new_class_functions = new_class_functions if new_class_functions is not None else defaultdict(list)
301313
self.modified_init_functions: dict[str, cst.FunctionDef] = (
302314
modified_init_functions if modified_init_functions is not None else {}
@@ -335,19 +347,33 @@ def leave_ClassDef(self, original_node: cst.ClassDef, updated_node: cst.ClassDef
335347
def leave_Module(self, original_node: cst.Module, updated_node: cst.Module) -> cst.Module: # noqa: ARG002
336348
node = updated_node
337349
max_function_index = None
338-
class_index = None
350+
max_class_index = None
339351
for index, _node in enumerate(node.body):
340352
if isinstance(_node, cst.FunctionDef):
341353
max_function_index = index
342354
if isinstance(_node, cst.ClassDef):
343-
class_index = index
355+
max_class_index = index
356+
357+
if self.new_classes:
358+
existing_class_names = {_node.name.value for _node in node.body if isinstance(_node, cst.ClassDef)}
359+
360+
unique_classes = [
361+
new_class for new_class in self.new_classes if new_class.name.value not in existing_class_names
362+
]
363+
if unique_classes:
364+
new_classes_insertion_idx = max_class_index or find_insertion_index_after_imports(node)
365+
new_body = list(
366+
chain(node.body[:new_classes_insertion_idx], unique_classes, node.body[new_classes_insertion_idx:])
367+
)
368+
node = node.with_changes(body=new_body)
369+
344370
if max_function_index is not None:
345371
node = node.with_changes(
346372
body=(*node.body[: max_function_index + 1], *self.new_functions, *node.body[max_function_index + 1 :])
347373
)
348-
elif class_index is not None:
374+
elif max_class_index is not None:
349375
node = node.with_changes(
350-
body=(*node.body[: class_index + 1], *self.new_functions, *node.body[class_index + 1 :])
376+
body=(*node.body[: max_class_index + 1], *self.new_functions, *node.body[max_class_index + 1 :])
351377
)
352378
else:
353379
node = node.with_changes(body=(*self.new_functions, *node.body))
@@ -373,18 +399,20 @@ def replace_functions_in_file(
373399
parsed_function_names.append((class_name, function_name))
374400

375401
# Collect functions we want to modify from the optimized code
376-
module = cst.metadata.MetadataWrapper(cst.parse_module(optimized_code))
402+
optimized_module = cst.metadata.MetadataWrapper(cst.parse_module(optimized_code))
403+
original_module = cst.parse_module(source_code)
404+
377405
visitor = OptimFunctionCollector(preexisting_objects, set(parsed_function_names))
378-
module.visit(visitor)
406+
optimized_module.visit(visitor)
379407

380408
# Replace these functions in the original code
381409
transformer = OptimFunctionReplacer(
382410
modified_functions=visitor.modified_functions,
411+
new_classes=visitor.new_classes,
383412
new_functions=visitor.new_functions,
384413
new_class_functions=visitor.new_class_functions,
385414
modified_init_functions=visitor.modified_init_functions,
386415
)
387-
original_module = cst.parse_module(source_code)
388416
modified_tree = original_module.visit(transformer)
389417
return modified_tree.code
390418

tests/test_code_replacement.py

Lines changed: 133 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -215,7 +215,7 @@ def new_function(self, value: cst.Name):
215215
return other_function(self.name)
216216
def new_function2(value):
217217
return value
218-
"""
218+
"""
219219

220220
original_code = """import libcst as cst
221221
from typing import Mandatory
@@ -230,19 +230,28 @@ def other_function(st):
230230
231231
print("Salut monde")
232232
"""
233-
expected = """from typing import Mandatory
233+
expected = """import libcst as cst
234+
from typing import Mandatory
235+
236+
class NewClass:
237+
def __init__(self, name):
238+
self.name = name
239+
def new_function(self, value: cst.Name):
240+
return other_function(self.name)
241+
def new_function2(value):
242+
return value
234243
235244
print("Au revoir")
236245
237246
def yet_another_function(values):
238247
return len(values)
239248
240-
def other_function(st):
241-
return(st * 2)
242-
243249
def totally_new_function(value):
244250
return value
245251
252+
def other_function(st):
253+
return(st * 2)
254+
246255
print("Salut monde")
247256
"""
248257

@@ -279,7 +288,7 @@ def new_function(self, value):
279288
return other_function(self.name)
280289
def new_function2(value):
281290
return value
282-
"""
291+
"""
283292

284293
original_code = """import libcst as cst
285294
from typing import Mandatory
@@ -296,17 +305,25 @@ def other_function(st):
296305
"""
297306
expected = """from typing import Mandatory
298307
308+
class NewClass:
309+
def __init__(self, name):
310+
self.name = name
311+
def new_function(self, value):
312+
return other_function(self.name)
313+
def new_function2(value):
314+
return value
315+
299316
print("Au revoir")
300317
301318
def yet_another_function(values):
302319
return len(values) + 2
303320
304-
def other_function(st):
305-
return(st * 2)
306-
307321
def totally_new_function(value):
308322
return value
309323
324+
def other_function(st):
325+
return(st * 2)
326+
310327
print("Salut monde")
311328
"""
312329

@@ -3619,4 +3636,110 @@ async def task():
36193636
await asyncio.sleep(1)
36203637
return "done"
36213638
'''
3622-
assert is_zero_diff(original_code, optimized_code)
3639+
assert is_zero_diff(original_code, optimized_code)
3640+
3641+
3642+
3643+
def test_code_replacement_with_new_helper_class() -> None:
3644+
optim_code = """from __future__ import annotations
3645+
3646+
import itertools
3647+
import re
3648+
from dataclasses import dataclass
3649+
from typing import Any, Callable, Iterator, Sequence
3650+
3651+
from bokeh.models import HoverTool, Plot, Tool
3652+
3653+
3654+
# Move the Item dataclass to module-level to avoid redefining it on every function call
3655+
@dataclass(frozen=True)
3656+
class _RepeatedToolItem:
3657+
obj: Tool
3658+
properties: dict[str, Any]
3659+
3660+
def _collect_repeated_tools(tool_objs: list[Tool]) -> Iterator[Tool]:
3661+
key: Callable[[Tool], str] = lambda obj: obj.__class__.__name__
3662+
# Pre-collect properties for all objects by group to avoid repeated calls
3663+
for _, group in itertools.groupby(sorted(tool_objs, key=key), key=key):
3664+
grouped = list(group)
3665+
n = len(grouped)
3666+
if n > 1:
3667+
# Precompute all properties once for this group
3668+
props = [_RepeatedToolItem(obj, obj.properties_with_values()) for obj in grouped]
3669+
i = 0
3670+
while i < len(props) - 1:
3671+
head = props[i]
3672+
for j in range(i+1, len(props)):
3673+
item = props[j]
3674+
if item.properties == head.properties:
3675+
yield item.obj
3676+
i += 1
3677+
"""
3678+
3679+
original_code = """from __future__ import annotations
3680+
import itertools
3681+
import re
3682+
from bokeh.models import HoverTool, Plot, Tool
3683+
from dataclasses import dataclass
3684+
from typing import Any, Callable, Iterator, Sequence
3685+
3686+
def _collect_repeated_tools(tool_objs: list[Tool]) -> Iterator[Tool]:
3687+
@dataclass(frozen=True)
3688+
class Item:
3689+
obj: Tool
3690+
properties: dict[str, Any]
3691+
3692+
key: Callable[[Tool], str] = lambda obj: obj.__class__.__name__
3693+
3694+
for _, group in itertools.groupby(sorted(tool_objs, key=key), key=key):
3695+
rest = [ Item(obj, obj.properties_with_values()) for obj in group ]
3696+
while len(rest) > 1:
3697+
head, *rest = rest
3698+
for item in rest:
3699+
if item.properties == head.properties:
3700+
yield item.obj
3701+
"""
3702+
3703+
expected = """from __future__ import annotations
3704+
import itertools
3705+
from bokeh.models import Tool
3706+
from dataclasses import dataclass
3707+
from typing import Any, Callable, Iterator
3708+
3709+
3710+
# Move the Item dataclass to module-level to avoid redefining it on every function call
3711+
@dataclass(frozen=True)
3712+
class _RepeatedToolItem:
3713+
obj: Tool
3714+
properties: dict[str, Any]
3715+
3716+
def _collect_repeated_tools(tool_objs: list[Tool]) -> Iterator[Tool]:
3717+
key: Callable[[Tool], str] = lambda obj: obj.__class__.__name__
3718+
# Pre-collect properties for all objects by group to avoid repeated calls
3719+
for _, group in itertools.groupby(sorted(tool_objs, key=key), key=key):
3720+
grouped = list(group)
3721+
n = len(grouped)
3722+
if n > 1:
3723+
# Precompute all properties once for this group
3724+
props = [_RepeatedToolItem(obj, obj.properties_with_values()) for obj in grouped]
3725+
i = 0
3726+
while i < len(props) - 1:
3727+
head = props[i]
3728+
for j in range(i+1, len(props)):
3729+
item = props[j]
3730+
if item.properties == head.properties:
3731+
yield item.obj
3732+
i += 1
3733+
"""
3734+
3735+
function_names: list[str] = ["_collect_repeated_tools"]
3736+
preexisting_objects: set[tuple[str, tuple[FunctionParent, ...]]] = find_preexisting_objects(original_code)
3737+
new_code: str = replace_functions_and_add_imports(
3738+
source_code=original_code,
3739+
function_names=function_names,
3740+
optimized_code=optim_code,
3741+
module_abspath=Path(__file__).resolve(),
3742+
preexisting_objects=preexisting_objects,
3743+
project_root_path=Path(__file__).resolve().parent.resolve(),
3744+
)
3745+
assert new_code == expected

0 commit comments

Comments
 (0)