Skip to content

Commit c6ab05f

Browse files
authored
Merge pull request #935 from codeflash-ai/fix/ctx-global-definitions-deps
[FIX] Extract the required module names inside the read-writable context (CF-862)
2 parents 7d18130 + da590d9 commit c6ab05f

File tree

3 files changed

+257
-73
lines changed

3 files changed

+257
-73
lines changed

codeflash/context/code_context_extractor.py

Lines changed: 32 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,11 @@
1212
from codeflash.cli_cmds.console import logger
1313
from codeflash.code_utils.code_extractor import add_needed_imports_from_module, find_preexisting_objects
1414
from codeflash.code_utils.code_utils import encoded_tokens_len, get_qualified_name, path_belongs_to_site_packages
15-
from codeflash.context.unused_definition_remover import remove_unused_definitions_by_function_names
15+
from codeflash.context.unused_definition_remover import (
16+
collect_top_level_defs_with_usages,
17+
extract_names_from_targets,
18+
remove_unused_definitions_by_function_names,
19+
)
1620
from codeflash.discovery.functions_to_optimize import FunctionToOptimize # noqa: TC001
1721
from codeflash.models.models import (
1822
CodeContextType,
@@ -29,6 +33,8 @@
2933
from jedi.api.classes import Name
3034
from libcst import CSTNode
3135

36+
from codeflash.context.unused_definition_remover import UsageInfo
37+
3238

3339
def get_code_optimization_context(
3440
function_to_optimize: FunctionToOptimize,
@@ -498,8 +504,10 @@ def parse_code_and_prune_cst(
498504
) -> str:
499505
"""Create a read-only version of the code by parsing and filtering the code to keep only class contextual information, and other module scoped variables."""
500506
module = cst.parse_module(code)
507+
defs_with_usages = collect_top_level_defs_with_usages(module, target_functions | helpers_of_helper_functions)
508+
501509
if code_context_type == CodeContextType.READ_WRITABLE:
502-
filtered_node, found_target = prune_cst_for_read_writable_code(module, target_functions)
510+
filtered_node, found_target = prune_cst_for_read_writable_code(module, target_functions, defs_with_usages)
503511
elif code_context_type == CodeContextType.READ_ONLY:
504512
filtered_node, found_target = prune_cst_for_read_only_code(
505513
module, target_functions, helpers_of_helper_functions, remove_docstrings=remove_docstrings
@@ -524,7 +532,7 @@ def parse_code_and_prune_cst(
524532

525533

526534
def prune_cst_for_read_writable_code( # noqa: PLR0911
527-
node: cst.CSTNode, target_functions: set[str], prefix: str = ""
535+
node: cst.CSTNode, target_functions: set[str], defs_with_usages: dict[str, UsageInfo], prefix: str = ""
528536
) -> tuple[cst.CSTNode | None, bool]:
529537
"""Recursively filter the node and its children to build the read-writable codeblock. This contains nodes that lead to target functions.
530538
@@ -569,6 +577,21 @@ def prune_cst_for_read_writable_code( # noqa: PLR0911
569577

570578
return node.with_changes(body=cst.IndentedBlock(body=new_body)), found_target
571579

580+
if isinstance(node, cst.Assign):
581+
for target in node.targets:
582+
names = extract_names_from_targets(target.target)
583+
for name in names:
584+
if name in defs_with_usages and defs_with_usages[name].used_by_qualified_function:
585+
return node, True
586+
return None, False
587+
588+
if isinstance(node, (cst.AnnAssign, cst.AugAssign)):
589+
names = extract_names_from_targets(node.target)
590+
for name in names:
591+
if name in defs_with_usages and defs_with_usages[name].used_by_qualified_function:
592+
return node, True
593+
return None, False
594+
572595
# For other nodes, we preserve them only if they contain target functions in their children.
573596
section_names = get_section_names(node)
574597
if not section_names:
@@ -583,7 +606,9 @@ def prune_cst_for_read_writable_code( # noqa: PLR0911
583606
new_children = []
584607
section_found_target = False
585608
for child in original_content:
586-
filtered, found_target = prune_cst_for_read_writable_code(child, target_functions, prefix)
609+
filtered, found_target = prune_cst_for_read_writable_code(
610+
child, target_functions, defs_with_usages, prefix
611+
)
587612
if filtered:
588613
new_children.append(filtered)
589614
section_found_target |= found_target
@@ -592,15 +617,16 @@ def prune_cst_for_read_writable_code( # noqa: PLR0911
592617
found_any_target = True
593618
updates[section] = new_children
594619
elif original_content is not None:
595-
filtered, found_target = prune_cst_for_read_writable_code(original_content, target_functions, prefix)
620+
filtered, found_target = prune_cst_for_read_writable_code(
621+
original_content, target_functions, defs_with_usages, prefix
622+
)
596623
if found_target:
597624
found_any_target = True
598625
if filtered:
599626
updates[section] = filtered
600627

601628
if not found_any_target:
602629
return None, False
603-
604630
return (node.with_changes(**updates) if updates else node), True
605631

606632

codeflash/context/unused_definition_remover.py

Lines changed: 81 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from dataclasses import dataclass, field
66
from itertools import chain
77
from pathlib import Path
8-
from typing import TYPE_CHECKING, Optional
8+
from typing import TYPE_CHECKING, Optional, Union
99

1010
import libcst as cst
1111

@@ -52,46 +52,64 @@ def collect_top_level_definitions(
5252
node: cst.CSTNode, definitions: Optional[dict[str, UsageInfo]] = None
5353
) -> dict[str, UsageInfo]:
5454
"""Recursively collect all top-level variable, function, and class definitions."""
55+
# Locally bind types and helpers for faster lookup
56+
FunctionDef = cst.FunctionDef # noqa: N806
57+
ClassDef = cst.ClassDef # noqa: N806
58+
Assign = cst.Assign # noqa: N806
59+
AnnAssign = cst.AnnAssign # noqa: N806
60+
AugAssign = cst.AugAssign # noqa: N806
61+
IndentedBlock = cst.IndentedBlock # noqa: N806
62+
5563
if definitions is None:
5664
definitions = {}
5765

58-
# Handle top-level function definitions
59-
if isinstance(node, cst.FunctionDef):
66+
# Speed: call type(node) once and compare by identity instead of repeated isinstance checks (exact-type match only)
67+
node_type = type(node)
68+
# Fast path: function def
69+
if node_type is FunctionDef:
6070
name = node.name.value
6171
definitions[name] = UsageInfo(
6272
name=name,
6373
used_by_qualified_function=False, # Will be marked later if in qualified functions
6474
)
6575
return definitions
6676

67-
# Handle top-level class definitions
68-
if isinstance(node, cst.ClassDef):
77+
# Fast path: class def
78+
if node_type is ClassDef:
6979
name = node.name.value
7080
definitions[name] = UsageInfo(name=name)
7181

72-
# Also collect method definitions within the class
73-
if hasattr(node, "body") and isinstance(node.body, cst.IndentedBlock):
74-
for statement in node.body.body:
75-
if isinstance(statement, cst.FunctionDef):
76-
method_name = f"{name}.{statement.name.value}"
82+
# Collect class methods
83+
body = getattr(node, "body", None)
84+
if body is not None and type(body) is IndentedBlock:
85+
statements = body.body
86+
# Build the "<ClassName>." prefix once instead of formatting an f-string for every method
87+
prefix = name + "."
88+
for statement in statements:
89+
if type(statement) is FunctionDef:
90+
method_name = prefix + statement.name.value
7791
definitions[method_name] = UsageInfo(name=method_name)
7892

7993
return definitions
8094

81-
# Handle top-level variable assignments
82-
if isinstance(node, cst.Assign):
83-
for target in node.targets:
95+
# Fast path: assignment
96+
if node_type is Assign:
97+
# Record every name that extract_names_from_targets returns for each assignment target
98+
targets = node.targets
99+
append_def = definitions.__setitem__
100+
for target in targets:
84101
names = extract_names_from_targets(target.target)
85102
for name in names:
86-
definitions[name] = UsageInfo(name=name)
103+
append_def(name, UsageInfo(name=name))
87104
return definitions
88105

89-
if isinstance(node, (cst.AnnAssign, cst.AugAssign)):
90-
if isinstance(node.target, cst.Name):
91-
name = node.target.value
106+
if node_type is AnnAssign or node_type is AugAssign:
107+
tgt = node.target
108+
if type(tgt) is cst.Name:
109+
name = tgt.value
92110
definitions[name] = UsageInfo(name=name)
93111
else:
94-
names = extract_names_from_targets(node.target)
112+
names = extract_names_from_targets(tgt)
95113
for name in names:
96114
definitions[name] = UsageInfo(name=name)
97115
return definitions
@@ -100,12 +118,15 @@ def collect_top_level_definitions(
100118
section_names = get_section_names(node)
101119

102120
if section_names:
121+
getattr_ = getattr
103122
for section in section_names:
104-
original_content = getattr(node, section, None)
123+
original_content = getattr_(node, section, None)
124+
# A section holds either a list/tuple of child nodes or a single node; handle each shape below
105125
# If section contains a list of nodes
106126
if isinstance(original_content, (list, tuple)):
127+
defs = definitions # Move out for minor speed
107128
for child in original_content:
108-
collect_top_level_definitions(child, definitions)
129+
collect_top_level_definitions(child, defs)
109130
# If section contains a single node
110131
elif original_content is not None:
111132
collect_top_level_definitions(original_content, definitions)
@@ -122,6 +143,8 @@ def get_section_names(node: cst.CSTNode) -> list[str]:
122143
class DependencyCollector(cst.CSTVisitor):
123144
"""Collects dependencies between definitions using the visitor pattern with depth tracking."""
124145

146+
METADATA_DEPENDENCIES = (cst.metadata.ParentNodeProvider,)
147+
125148
def __init__(self, definitions: dict[str, UsageInfo]) -> None:
126149
super().__init__()
127150
self.definitions = definitions
@@ -259,8 +282,12 @@ def visit_Name(self, node: cst.Name) -> None:
259282
if self.processing_variable and name in self.current_variable_names:
260283
return
261284

262-
# Check if name is a top-level definition we're tracking
263285
if name in self.definitions and name != self.current_top_level_name:
286+
# Skip if we are referencing a class attribute and not a top-level definition
287+
if self.class_depth > 0:
288+
parent = self.get_metadata(cst.metadata.ParentNodeProvider, node)
289+
if parent is not None and isinstance(parent, cst.Attribute):
290+
return
264291
self.definitions[self.current_top_level_name].dependencies.add(name)
265292

266293

@@ -293,13 +320,20 @@ def _expand_qualified_functions(self) -> set[str]:
293320

294321
def mark_used_definitions(self) -> None:
295322
"""Find all qualified functions and mark them and their dependencies as used."""
296-
# First identify all specified functions (including expanded ones)
297-
functions_to_mark = [name for name in self.expanded_qualified_functions if name in self.definitions]
323+
# Avoid list comprehension for set intersection
324+
expanded_names = self.expanded_qualified_functions
325+
defs = self.definitions
326+
# Use set intersection when expanded_names is a set (dict key views are set-like); otherwise fall back to filtering
327+
fnames = (
328+
expanded_names & defs.keys()
329+
if isinstance(expanded_names, set)
330+
else [name for name in expanded_names if name in defs]
331+
)
298332

299333
# For each specified function, mark it and all its dependencies as used
300-
for func_name in functions_to_mark:
301-
self.definitions[func_name].used_by_qualified_function = True
302-
for dep in self.definitions[func_name].dependencies:
334+
for func_name in fnames:
335+
defs[func_name].used_by_qualified_function = True
336+
for dep in defs[func_name].dependencies:
303337
self.mark_as_used_recursively(dep)
304338

305339
def mark_as_used_recursively(self, name: str) -> None:
@@ -457,6 +491,25 @@ def remove_unused_definitions_recursively( # noqa: PLR0911
457491
return node, False
458492

459493

494+
def collect_top_level_defs_with_usages(
495+
code: Union[str, cst.Module], qualified_function_names: set[str]
496+
) -> dict[str, UsageInfo]:
497+
"""Collect all top level definitions (classes, variables or functions) and their usages."""
498+
module = code if isinstance(code, cst.Module) else cst.parse_module(code)
499+
# Collect all definitions (top level classes, variables or function)
500+
definitions = collect_top_level_definitions(module)
501+
502+
# Collect dependencies between definitions using the visitor pattern
503+
wrapper = cst.MetadataWrapper(module)
504+
dependency_collector = DependencyCollector(definitions)
505+
wrapper.visit(dependency_collector)
506+
507+
# Mark definitions used by specified functions, and their dependencies recursively
508+
usage_marker = QualifiedFunctionUsageMarker(definitions, qualified_function_names)
509+
usage_marker.mark_used_definitions()
510+
return definitions
511+
512+
460513
def remove_unused_definitions_by_function_names(code: str, qualified_function_names: set[str]) -> str:
461514
"""Analyze a file and remove top level definitions not used by specified functions.
462515
@@ -476,19 +529,10 @@ def remove_unused_definitions_by_function_names(code: str, qualified_function_na
476529
return code
477530

478531
try:
479-
# Collect all definitions (top level classes, variables or function)
480-
definitions = collect_top_level_definitions(module)
481-
482-
# Collect dependencies between definitions using the visitor pattern
483-
dependency_collector = DependencyCollector(definitions)
484-
module.visit(dependency_collector)
485-
486-
# Mark definitions used by specified functions, and their dependencies recursively
487-
usage_marker = QualifiedFunctionUsageMarker(definitions, qualified_function_names)
488-
usage_marker.mark_used_definitions()
532+
defs_with_usages = collect_top_level_defs_with_usages(module, qualified_function_names)
489533

490534
# Apply the recursive removal transformation
491-
modified_module, _ = remove_unused_definitions_recursively(module, definitions)
535+
modified_module, _ = remove_unused_definitions_recursively(module, defs_with_usages)
492536

493537
return modified_module.code if modified_module else "" # noqa: TRY300
494538
except Exception as e:

0 commit comments

Comments
 (0)