Skip to content

Commit 2d886e8

Browse files
Merge pull request #763 from codeflash-ai/fix/correctly-find-funtion-node-when-reverting-helpers
[FIX] Respect parent classes in revert helpers
2 parents f296a0f + 76123d0 commit 2d886e8

File tree

2 files changed

+179
-6
lines changed

2 files changed

+179
-6
lines changed

codeflash/context/unused_definition_remover.py

Lines changed: 29 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -612,6 +612,34 @@ def _analyze_imports_in_optimized_code(
612612
return dict(imported_names_map)
613613

614614

615+
def find_target_node(
616+
root: ast.AST, function_to_optimize: FunctionToOptimize
617+
) -> Optional[ast.FunctionDef | ast.AsyncFunctionDef]:
618+
parents = function_to_optimize.parents
619+
node = root
620+
for parent in parents:
621+
# Fast loop: directly look for the matching ClassDef in node.body
622+
body = getattr(node, "body", None)
623+
if not body:
624+
return None
625+
for child in body:
626+
if isinstance(child, ast.ClassDef) and child.name == parent.name:
627+
node = child
628+
break
629+
else:
630+
return None
631+
632+
# Now node is either the root or the target parent class; look for function
633+
body = getattr(node, "body", None)
634+
if not body:
635+
return None
636+
target_name = function_to_optimize.function_name
637+
for child in body:
638+
if isinstance(child, (ast.FunctionDef, ast.AsyncFunctionDef)) and child.name == target_name:
639+
return child
640+
return None
641+
642+
615643
def detect_unused_helper_functions(
616644
function_to_optimize: FunctionToOptimize,
617645
code_context: CodeOptimizationContext,
@@ -641,11 +669,7 @@ def detect_unused_helper_functions(
641669
optimized_ast = ast.parse(optimized_code)
642670

643671
# Find the optimized entrypoint function
644-
entrypoint_function_ast = None
645-
for node in ast.walk(optimized_ast):
646-
if isinstance(node, ast.FunctionDef) and node.name == function_to_optimize.function_name:
647-
entrypoint_function_ast = node
648-
break
672+
entrypoint_function_ast = find_target_node(optimized_ast, function_to_optimize)
649673

650674
if not entrypoint_function_ast:
651675
logger.debug(f"Could not find entrypoint function {function_to_optimize.function_name} in optimized code")

tests/test_unused_helper_revert.py

Lines changed: 150 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
import pytest
77
from codeflash.context.unused_definition_remover import detect_unused_helper_functions
88
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
9-
from codeflash.models.models import CodeStringsMarkdown
9+
from codeflash.models.models import CodeStringsMarkdown, FunctionParent
1010
from codeflash.optimization.function_optimizer import FunctionOptimizer
1111
from codeflash.verification.verification_utils import TestConfig
1212

@@ -1460,3 +1460,152 @@ def calculate_class(cls, n):
14601460
import shutil
14611461

14621462
shutil.rmtree(temp_dir, ignore_errors=True)
1463+
1464+
1465+
def test_unused_helper_detection_with_duplicated_function_name_in_different_classes():
1466+
"""Test detection when helpers are called via module.function style."""
1467+
temp_dir = Path(tempfile.mkdtemp())
1468+
1469+
try:
1470+
# Main file
1471+
main_file = temp_dir / "main.py"
1472+
main_file.write_text("""from __future__ import annotations
1473+
import json
1474+
from helpers import replace_quotes_with_backticks, simplify_worktree_paths
1475+
from dataclasses import asdict, dataclass
1476+
1477+
@dataclass
1478+
class LspMessage:
1479+
1480+
def serialize(self) -> str:
1481+
data = self._loop_through(asdict(self))
1482+
# Important: keep type as the first key, for making it easy and fast for the client to know if this is a lsp message before parsing it
1483+
ordered = {"type": self.type(), **data}
1484+
return (
1485+
message_delimiter
1486+
+ json.dumps(ordered)
1487+
+ message_delimiter
1488+
)
1489+
1490+
1491+
@dataclass
1492+
class LspMarkdownMessage(LspMessage):
1493+
1494+
def serialize(self) -> str:
1495+
self.markdown = simplify_worktree_paths(self.markdown)
1496+
self.markdown = replace_quotes_with_backticks(self.markdown)
1497+
return super().serialize()
1498+
""")
1499+
1500+
# Helpers file
1501+
helpers_file = temp_dir / "helpers.py"
1502+
helpers_file.write_text("""def simplify_worktree_paths(msg: str, highlight: bool = True) -> str: # noqa: FBT001, FBT002
1503+
path_in_msg = worktree_path_regex.search(msg)
1504+
if path_in_msg:
1505+
last_part_of_path = path_in_msg.group(0).split("/")[-1]
1506+
if highlight:
1507+
last_part_of_path = f"`{last_part_of_path}`"
1508+
return msg.replace(path_in_msg.group(0), last_part_of_path)
1509+
return msg
1510+
1511+
1512+
def replace_quotes_with_backticks(text: str) -> str:
1513+
# double-quoted strings
1514+
text = _double_quote_pat.sub(r"`\1`", text)
1515+
# single-quoted strings
1516+
return _single_quote_pat.sub(r"`\1`", text)
1517+
""")
1518+
1519+
# Optimized version that only uses add_numbers
1520+
optimized_code = """
1521+
```python:main.py
1522+
from __future__ import annotations
1523+
1524+
import json
1525+
from dataclasses import asdict, dataclass
1526+
1527+
from codeflash.lsp.helpers import (replace_quotes_with_backticks,
1528+
simplify_worktree_paths)
1529+
1530+
1531+
@dataclass
1532+
class LspMessage:
1533+
1534+
def serialize(self) -> str:
1535+
# Use local variable to minimize lookup costs and avoid unnecessary dictionary unpacking
1536+
data = self._loop_through(asdict(self))
1537+
msg_type = self.type()
1538+
ordered = {'type': msg_type}
1539+
ordered.update(data)
1540+
return (
1541+
message_delimiter
1542+
+ json.dumps(ordered)
1543+
+ message_delimiter # \u241F is the message delimiter becuase it can be more than one message sent over the same message, so we need something to separate each message
1544+
)
1545+
1546+
@dataclass
1547+
class LspMarkdownMessage(LspMessage):
1548+
1549+
def serialize(self) -> str:
1550+
# Side effect required, must preserve for behavioral correctness
1551+
self.markdown = simplify_worktree_paths(self.markdown)
1552+
self.markdown = replace_quotes_with_backticks(self.markdown)
1553+
return super().serialize()
1554+
```
1555+
```python:helpers.py
1556+
def simplify_worktree_paths(msg: str, highlight: bool = True) -> str: # noqa: FBT001, FBT002
1557+
m = worktree_path_regex.search(msg)
1558+
if m:
1559+
# More efficient way to get last path part
1560+
last_part_of_path = m.group(0).rpartition('/')[-1]
1561+
if highlight:
1562+
last_part_of_path = f"`{last_part_of_path}`"
1563+
return msg.replace(m.group(0), last_part_of_path)
1564+
return msg
1565+
1566+
def replace_quotes_with_backticks(text: str) -> str:
1567+
# Efficient string substitution, reduces intermediate string allocations
1568+
return _single_quote_pat.sub(
1569+
r"`\1`",
1570+
_double_quote_pat.sub(r"`\1`", text),
1571+
)
1572+
```
1573+
"""
1574+
1575+
# Create test config
1576+
test_cfg = TestConfig(
1577+
tests_root=temp_dir / "tests",
1578+
tests_project_rootdir=temp_dir,
1579+
project_root_path=temp_dir,
1580+
test_framework="pytest",
1581+
pytest_cmd="pytest",
1582+
)
1583+
1584+
# Create FunctionToOptimize instance
1585+
function_to_optimize = FunctionToOptimize(
1586+
file_path=main_file, function_name="serialize", qualified_name="serialize", parents=[
1587+
FunctionParent(name="LspMarkdownMessage", type="ClassDef"),
1588+
]
1589+
)
1590+
1591+
optimizer = FunctionOptimizer(
1592+
function_to_optimize=function_to_optimize,
1593+
test_cfg=test_cfg,
1594+
function_to_optimize_source_code=main_file.read_text(),
1595+
)
1596+
1597+
ctx_result = optimizer.get_code_optimization_context()
1598+
assert ctx_result.is_successful(), f"Failed to get context: {ctx_result.failure()}"
1599+
1600+
code_context = ctx_result.unwrap()
1601+
1602+
unused_helpers = detect_unused_helper_functions(optimizer.function_to_optimize, code_context, CodeStringsMarkdown.parse_markdown_code(optimized_code))
1603+
1604+
unused_names = {uh.qualified_name for uh in unused_helpers}
1605+
assert len(unused_names) == 0 # no unused helpers
1606+
1607+
finally:
1608+
# Cleanup
1609+
import shutil
1610+
1611+
shutil.rmtree(temp_dir, ignore_errors=True)

0 commit comments

Comments
 (0)