@@ -215,7 +215,7 @@ def new_function(self, value: cst.Name):
215215 return other_function(self.name)
216216 def new_function2(value):
217217 return value
218- """
218+ """
219219
220220 original_code = """import libcst as cst
221221from typing import Mandatory
@@ -230,19 +230,28 @@ def other_function(st):
230230
231231print("Salut monde")
232232"""
233- expected = """from typing import Mandatory
233+ expected = """import libcst as cst
234+ from typing import Mandatory
235+
236+ class NewClass:
237+ def __init__(self, name):
238+ self.name = name
239+ def new_function(self, value: cst.Name):
240+ return other_function(self.name)
241+ def new_function2(value):
242+ return value
234243
235244print("Au revoir")
236245
237246def yet_another_function(values):
238247 return len(values)
239248
240- def other_function(st):
241- return(st * 2)
242-
243249def totally_new_function(value):
244250 return value
245251
252+ def other_function(st):
253+ return(st * 2)
254+
246255print("Salut monde")
247256"""
248257
@@ -279,7 +288,7 @@ def new_function(self, value):
279288 return other_function(self.name)
280289 def new_function2(value):
281290 return value
282- """
291+ """
283292
284293 original_code = """import libcst as cst
285294from typing import Mandatory
@@ -296,17 +305,25 @@ def other_function(st):
296305"""
297306 expected = """from typing import Mandatory
298307
308+ class NewClass:
309+ def __init__(self, name):
310+ self.name = name
311+ def new_function(self, value):
312+ return other_function(self.name)
313+ def new_function2(value):
314+ return value
315+
299316print("Au revoir")
300317
301318def yet_another_function(values):
302319 return len(values) + 2
303320
304- def other_function(st):
305- return(st * 2)
306-
307321def totally_new_function(value):
308322 return value
309323
324+ def other_function(st):
325+ return(st * 2)
326+
310327print("Salut monde")
311328"""
312329
@@ -3619,4 +3636,110 @@ async def task():
36193636 await asyncio.sleep(1)
36203637 return "done"
36213638'''
3622- assert is_zero_diff (original_code , optimized_code )
3639+ assert is_zero_diff (original_code , optimized_code )
3640+
3641+
3642+
3643+ def test_code_replacement_with_new_helper_class () -> None :
3644+ optim_code = """from __future__ import annotations
3645+
3646+ import itertools
3647+ import re
3648+ from dataclasses import dataclass
3649+ from typing import Any, Callable, Iterator, Sequence
3650+
3651+ from bokeh.models import HoverTool, Plot, Tool
3652+
3653+
3654+ # Move the Item dataclass to module-level to avoid redefining it on every function call
3655+ @dataclass(frozen=True)
3656+ class _RepeatedToolItem:
3657+ obj: Tool
3658+ properties: dict[str, Any]
3659+
3660+ def _collect_repeated_tools(tool_objs: list[Tool]) -> Iterator[Tool]:
3661+ key: Callable[[Tool], str] = lambda obj: obj.__class__.__name__
3662+ # Pre-collect properties for all objects by group to avoid repeated calls
3663+ for _, group in itertools.groupby(sorted(tool_objs, key=key), key=key):
3664+ grouped = list(group)
3665+ n = len(grouped)
3666+ if n > 1:
3667+ # Precompute all properties once for this group
3668+ props = [_RepeatedToolItem(obj, obj.properties_with_values()) for obj in grouped]
3669+ i = 0
3670+ while i < len(props) - 1:
3671+ head = props[i]
3672+ for j in range(i+1, len(props)):
3673+ item = props[j]
3674+ if item.properties == head.properties:
3675+ yield item.obj
3676+ i += 1
3677+ """
3678+
3679+ original_code = """from __future__ import annotations
3680+ import itertools
3681+ import re
3682+ from bokeh.models import HoverTool, Plot, Tool
3683+ from dataclasses import dataclass
3684+ from typing import Any, Callable, Iterator, Sequence
3685+
3686+ def _collect_repeated_tools(tool_objs: list[Tool]) -> Iterator[Tool]:
3687+ @dataclass(frozen=True)
3688+ class Item:
3689+ obj: Tool
3690+ properties: dict[str, Any]
3691+
3692+ key: Callable[[Tool], str] = lambda obj: obj.__class__.__name__
3693+
3694+ for _, group in itertools.groupby(sorted(tool_objs, key=key), key=key):
3695+ rest = [ Item(obj, obj.properties_with_values()) for obj in group ]
3696+ while len(rest) > 1:
3697+ head, *rest = rest
3698+ for item in rest:
3699+ if item.properties == head.properties:
3700+ yield item.obj
3701+ """
3702+
3703+ expected = """from __future__ import annotations
3704+ import itertools
3705+ from bokeh.models import Tool
3706+ from dataclasses import dataclass
3707+ from typing import Any, Callable, Iterator
3708+
3709+
3710+ # Move the Item dataclass to module-level to avoid redefining it on every function call
3711+ @dataclass(frozen=True)
3712+ class _RepeatedToolItem:
3713+ obj: Tool
3714+ properties: dict[str, Any]
3715+
3716+ def _collect_repeated_tools(tool_objs: list[Tool]) -> Iterator[Tool]:
3717+ key: Callable[[Tool], str] = lambda obj: obj.__class__.__name__
3718+ # Pre-collect properties for all objects by group to avoid repeated calls
3719+ for _, group in itertools.groupby(sorted(tool_objs, key=key), key=key):
3720+ grouped = list(group)
3721+ n = len(grouped)
3722+ if n > 1:
3723+ # Precompute all properties once for this group
3724+ props = [_RepeatedToolItem(obj, obj.properties_with_values()) for obj in grouped]
3725+ i = 0
3726+ while i < len(props) - 1:
3727+ head = props[i]
3728+ for j in range(i+1, len(props)):
3729+ item = props[j]
3730+ if item.properties == head.properties:
3731+ yield item.obj
3732+ i += 1
3733+ """
3734+
3735+ function_names : list [str ] = ["_collect_repeated_tools" ]
3736+ preexisting_objects : set [tuple [str , tuple [FunctionParent , ...]]] = find_preexisting_objects (original_code )
3737+ new_code : str = replace_functions_and_add_imports (
3738+ source_code = original_code ,
3739+ function_names = function_names ,
3740+ optimized_code = optim_code ,
3741+ module_abspath = Path (__file__ ).resolve (),
3742+ preexisting_objects = preexisting_objects ,
3743+ project_root_path = Path (__file__ ).resolve ().parent .resolve (),
3744+ )
3745+ assert new_code == expected
0 commit comments