Skip to content

Commit aab93ef

Browse files
Merge branch 'main' into parallel-pytest-tracing
2 parents 060a78c + 922f714 commit aab93ef

File tree

5 files changed

+202
-30
lines changed

5 files changed

+202
-30
lines changed

codeflash/code_utils/code_extractor.py

Lines changed: 32 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -335,12 +335,12 @@ def leave_Module(self, original_node: cst.Module, updated_node: cst.Module) -> c
335335
return updated_node
336336

337337

338-
def extract_global_statements(source_code: str) -> list[cst.SimpleStatementLine]:
338+
def extract_global_statements(source_code: str) -> tuple[cst.Module, list[cst.SimpleStatementLine]]:
339339
"""Extract global statements from source code."""
340340
module = cst.parse_module(source_code)
341341
collector = GlobalStatementCollector()
342342
module.visit(collector)
343-
return collector.global_statements
343+
return module, collector.global_statements
344344

345345

346346
def find_last_import_line(target_code: str) -> int:
@@ -373,30 +373,41 @@ def delete___future___aliased_imports(module_code: str) -> str:
373373

374374

375375
def add_global_assignments(src_module_code: str, dst_module_code: str) -> str:
376-
non_assignment_global_statements = extract_global_statements(src_module_code)
376+
src_module, new_added_global_statements = extract_global_statements(src_module_code)
377+
dst_module, existing_global_statements = extract_global_statements(dst_module_code)
377378

378-
# Find the last import line in target
379-
last_import_line = find_last_import_line(dst_module_code)
380-
381-
# Parse the target code
382-
target_module = cst.parse_module(dst_module_code)
383-
384-
# Create transformer to insert non_assignment_global_statements
385-
transformer = ImportInserter(non_assignment_global_statements, last_import_line)
386-
#
387-
# # Apply transformation
388-
modified_module = target_module.visit(transformer)
389-
dst_module_code = modified_module.code
390-
391-
# Parse the code
392-
original_module = cst.parse_module(dst_module_code)
393-
new_module = cst.parse_module(src_module_code)
379+
unique_global_statements = []
380+
for stmt in new_added_global_statements:
381+
if any(
382+
stmt is existing_stmt or stmt.deep_equals(existing_stmt) for existing_stmt in existing_global_statements
383+
):
384+
continue
385+
unique_global_statements.append(stmt)
386+
387+
mod_dst_code = dst_module_code
388+
# Insert unique global statements if any
389+
if unique_global_statements:
390+
last_import_line = find_last_import_line(dst_module_code)
391+
# Reuse already-parsed dst_module
392+
transformer = ImportInserter(unique_global_statements, last_import_line)
393+
# Use visit inplace, don't parse again
394+
modified_module = dst_module.visit(transformer)
395+
mod_dst_code = modified_module.code
396+
# Parse the code after insertion
397+
original_module = cst.parse_module(mod_dst_code)
398+
else:
399+
# No new statements to insert, reuse already-parsed dst_module
400+
original_module = dst_module
394401

402+
# Parse the src_module_code once only (already done above: src_module)
395403
# Collect assignments from the new file
396404
new_collector = GlobalAssignmentCollector()
397-
new_module.visit(new_collector)
405+
src_module.visit(new_collector)
406+
# Only create transformer if there are assignments to insert/transform
407+
if not new_collector.assignments: # nothing to transform
408+
return mod_dst_code
398409

399-
# Transform the original file
410+
# Transform the original destination module
400411
transformer = GlobalAssignmentTransformer(new_collector.assignments, new_collector.assignment_order)
401412
transformed_module = original_module.visit(transformer)
402413

codeflash/code_utils/code_replacer.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -412,11 +412,17 @@ def replace_function_definitions_in_module(
412412
module_abspath: Path,
413413
preexisting_objects: set[tuple[str, tuple[FunctionParent, ...]]],
414414
project_root_path: Path,
415+
should_add_global_assignments: bool = True, # noqa: FBT001, FBT002
415416
) -> bool:
416417
source_code: str = module_abspath.read_text(encoding="utf8")
417418
code_to_apply = get_optimized_code_for_module(module_abspath.relative_to(project_root_path), optimized_code)
419+
418420
new_code: str = replace_functions_and_add_imports(
419-
add_global_assignments(code_to_apply, source_code),
421+
# adding the global assignments before replacing the code, not after
422+
# because of an "edge case" where the optimized code introduced a new import and a global assignment using that import
423+
# and that import wasn't used before, so it was ignored when calling AddImportsVisitor.add_needed_import inside replace_functions_and_add_imports (because the global assignment wasn't added yet)
424+
# this was added at https://github.com/codeflash-ai/codeflash/pull/448
425+
add_global_assignments(code_to_apply, source_code) if should_add_global_assignments else source_code,
420426
function_names,
421427
code_to_apply,
422428
module_abspath,

codeflash/context/unused_definition_remover.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -537,6 +537,7 @@ def revert_unused_helper_functions(
537537
module_abspath=file_path,
538538
preexisting_objects=set(), # Empty set since we're reverting
539539
project_root_path=project_root,
540+
should_add_global_assignments=False, # since we revert helper functions after applying the optimization, we know that the file already has global assignments added; otherwise they would be added twice.
540541
)
541542

542543
if reverted_code:

codeflash/optimization/function_optimizer.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@
5353
)
5454
from codeflash.code_utils.env_utils import get_pr_number
5555
from codeflash.code_utils.formatter import format_code, sort_imports
56+
from codeflash.code_utils.git_utils import git_root_dir
5657
from codeflash.code_utils.instrument_existing_tests import inject_profiling_into_existing_test
5758
from codeflash.code_utils.line_profile_utils import add_decorator_imports
5859
from codeflash.code_utils.static_analysis import get_first_top_level_function_or_method_ast
@@ -820,7 +821,10 @@ def reformat_code_and_helpers(
820821
return new_code, new_helper_code
821822

822823
def replace_function_and_helpers_with_optimized_code(
823-
self, code_context: CodeOptimizationContext, optimized_code: CodeStringsMarkdown, original_helper_code: str
824+
self,
825+
code_context: CodeOptimizationContext,
826+
optimized_code: CodeStringsMarkdown,
827+
original_helper_code: dict[Path, str],
824828
) -> bool:
825829
did_update = False
826830
read_writable_functions_by_file_path = defaultdict(set)
@@ -1298,11 +1302,13 @@ def process_review(
12981302
"coverage_message": coverage_message,
12991303
"replay_tests": replay_tests,
13001304
"concolic_tests": concolic_tests,
1301-
"root_dir": self.project_root,
13021305
}
13031306

13041307
raise_pr = not self.args.no_pr
13051308

1309+
if raise_pr or self.args.staging_review:
1310+
data["root_dir"] = git_root_dir()
1311+
13061312
if raise_pr and not self.args.staging_review:
13071313
data["git_remote"] = self.args.git_remote
13081314
check_create_pr(**data)

tests/test_code_replacement.py

Lines changed: 154 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1707,7 +1707,6 @@ def new_function2(value):
17071707
"""
17081708
expected_code = """import numpy as np
17091709
1710-
print("Hello world")
17111710
a=2
17121711
print("Hello world")
17131712
def some_fn():
@@ -1783,7 +1782,6 @@ def new_function2(value):
17831782
"""
17841783
expected_code = """import numpy as np
17851784
1786-
print("Hello world")
17871785
print("Hello world")
17881786
def some_fn():
17891787
a=np.zeros(10)
@@ -1862,7 +1860,6 @@ def new_function2(value):
18621860
"""
18631861
expected_code = """import numpy as np
18641862
1865-
print("Hello world")
18661863
a=3
18671864
print("Hello world")
18681865
def some_fn():
@@ -1940,7 +1937,6 @@ def new_function2(value):
19401937
"""
19411938
expected_code = """import numpy as np
19421939
1943-
print("Hello world")
19441940
a=2
19451941
print("Hello world")
19461942
def some_fn():
@@ -2019,7 +2015,6 @@ def new_function2(value):
20192015
"""
20202016
expected_code = """import numpy as np
20212017
2022-
print("Hello world")
20232018
a=3
20242019
print("Hello world")
20252020
def some_fn():
@@ -2106,7 +2101,6 @@ def new_function2(value):
21062101
21072102
a = 6
21082103
2109-
print("Hello world")
21102104
if 2<3:
21112105
a=4
21122106
else:
@@ -3453,3 +3447,157 @@ def hydrate_input_text_actions_with_field_names(
34533447
main_file.unlink(missing_ok=True)
34543448

34553449
assert new_code == expected
3450+
3451+
def test_duplicate_global_assignments_when_reverting_helpers():
3452+
root_dir = Path(__file__).parent.parent.resolve()
3453+
main_file = Path(root_dir / "code_to_optimize/temp_main.py").resolve()
3454+
3455+
original_code = '''"""Chunking objects not specific to a particular chunking strategy."""
3456+
from __future__ import annotations
3457+
import collections
3458+
import copy
3459+
from typing import Any, Callable, DefaultDict, Iterable, Iterator, cast
3460+
import regex
3461+
from typing_extensions import Self, TypeAlias
3462+
from unstructured.utils import lazyproperty
3463+
from unstructured.documents.elements import Element
3464+
# ================================================================================================
3465+
# MODEL
3466+
# ================================================================================================
3467+
CHUNK_MAX_CHARS_DEFAULT: int = 500
3468+
# ================================================================================================
3469+
# PRE-CHUNKER
3470+
# ================================================================================================
3471+
class PreChunker:
3472+
"""Gathers sequential elements into pre-chunks as length constraints allow.
3473+
The pre-chunker's responsibilities are:
3474+
- **Segregate semantic units.** Identify semantic unit boundaries and segregate elements on
3475+
either side of those boundaries into different sections. In this case, the primary indicator
3476+
of a semantic boundary is a `Title` element. A page-break (change in page-number) is also a
3477+
semantic boundary when `multipage_sections` is `False`.
3478+
- **Minimize chunk count for each semantic unit.** Group the elements within a semantic unit
3479+
into sections as big as possible without exceeding the chunk window size.
3480+
- **Minimize chunks that must be split mid-text.** Precompute the text length of each section
3481+
and only produce a section that exceeds the chunk window size when there is a single element
3482+
with text longer than that window.
3483+
A Table element is placed into a section by itself. CheckBox elements are dropped.
3484+
The "by-title" strategy specifies breaking on section boundaries; a `Title` element indicates
3485+
a new "section", hence the "by-title" designation.
3486+
"""
3487+
def __init__(self, elements: Iterable[Element], opts: ChunkingOptions):
3488+
self._elements = elements
3489+
self._opts = opts
3490+
@lazyproperty
3491+
def _boundary_predicates(self) -> tuple[BoundaryPredicate, ...]:
3492+
"""The semantic-boundary detectors to be applied to break pre-chunks."""
3493+
return self._opts.boundary_predicates
3494+
def _is_in_new_semantic_unit(self, element: Element) -> bool:
3495+
"""True when `element` begins a new semantic unit such as a section or page."""
3496+
# -- all detectors need to be called to update state and avoid double counting
3497+
# -- boundaries that happen to coincide, like Table and new section on same element.
3498+
# -- Using `any()` would short-circuit on first True.
3499+
semantic_boundaries = [pred(element) for pred in self._boundary_predicates]
3500+
return any(semantic_boundaries)
3501+
'''
3502+
main_file.write_text(original_code, encoding="utf-8")
3503+
optim_code = f'''```python:{main_file.relative_to(root_dir)}
3504+
# ================================================================================================
3505+
# PRE-CHUNKER
3506+
# ================================================================================================
3507+
from __future__ import annotations
3508+
from typing import Iterable
3509+
from unstructured.documents.elements import Element
3510+
from unstructured.utils import lazyproperty
3511+
class PreChunker:
3512+
def __init__(self, elements: Iterable[Element], opts: ChunkingOptions):
3513+
self._elements = elements
3514+
self._opts = opts
3515+
@lazyproperty
3516+
def _boundary_predicates(self) -> tuple[BoundaryPredicate, ...]:
3517+
"""The semantic-boundary detectors to be applied to break pre-chunks."""
3518+
return self._opts.boundary_predicates
3519+
def _is_in_new_semantic_unit(self, element: Element) -> bool:
3520+
"""True when `element` begins a new semantic unit such as a section or page."""
3521+
# Use generator expression for lower memory usage and avoid building intermediate list
3522+
for pred in self._boundary_predicates:
3523+
if pred(element):
3524+
return True
3525+
return False
3526+
```
3527+
'''
3528+
3529+
func = FunctionToOptimize(function_name="_is_in_new_semantic_unit", parents=[FunctionParent("PreChunker", "ClassDef")], file_path=main_file)
3530+
test_config = TestConfig(
3531+
tests_root=root_dir / "tests/pytest",
3532+
tests_project_rootdir=root_dir,
3533+
project_root_path=root_dir,
3534+
test_framework="pytest",
3535+
pytest_cmd="pytest",
3536+
)
3537+
func_optimizer = FunctionOptimizer(function_to_optimize=func, test_cfg=test_config)
3538+
code_context: CodeOptimizationContext = func_optimizer.get_code_optimization_context().unwrap()
3539+
3540+
original_helper_code: dict[Path, str] = {}
3541+
helper_function_paths = {hf.file_path for hf in code_context.helper_functions}
3542+
for helper_function_path in helper_function_paths:
3543+
with helper_function_path.open(encoding="utf8") as f:
3544+
helper_code = f.read()
3545+
original_helper_code[helper_function_path] = helper_code
3546+
3547+
func_optimizer.args = Args()
3548+
func_optimizer.replace_function_and_helpers_with_optimized_code(
3549+
code_context=code_context, optimized_code=CodeStringsMarkdown.parse_markdown_code(optim_code), original_helper_code=original_helper_code
3550+
)
3551+
3552+
3553+
new_code = main_file.read_text(encoding="utf-8")
3554+
main_file.unlink(missing_ok=True)
3555+
3556+
expected = '''"""Chunking objects not specific to a particular chunking strategy."""
3557+
from __future__ import annotations
3558+
import collections
3559+
import copy
3560+
from typing import Any, Callable, DefaultDict, Iterable, Iterator, cast
3561+
import regex
3562+
from typing_extensions import Self, TypeAlias
3563+
from unstructured.utils import lazyproperty
3564+
from unstructured.documents.elements import Element
3565+
# ================================================================================================
3566+
# MODEL
3567+
# ================================================================================================
3568+
CHUNK_MAX_CHARS_DEFAULT: int = 500
3569+
# ================================================================================================
3570+
# PRE-CHUNKER
3571+
# ================================================================================================
3572+
class PreChunker:
3573+
"""Gathers sequential elements into pre-chunks as length constraints allow.
3574+
The pre-chunker's responsibilities are:
3575+
- **Segregate semantic units.** Identify semantic unit boundaries and segregate elements on
3576+
either side of those boundaries into different sections. In this case, the primary indicator
3577+
of a semantic boundary is a `Title` element. A page-break (change in page-number) is also a
3578+
semantic boundary when `multipage_sections` is `False`.
3579+
- **Minimize chunk count for each semantic unit.** Group the elements within a semantic unit
3580+
into sections as big as possible without exceeding the chunk window size.
3581+
- **Minimize chunks that must be split mid-text.** Precompute the text length of each section
3582+
and only produce a section that exceeds the chunk window size when there is a single element
3583+
with text longer than that window.
3584+
A Table element is placed into a section by itself. CheckBox elements are dropped.
3585+
The "by-title" strategy specifies breaking on section boundaries; a `Title` element indicates
3586+
a new "section", hence the "by-title" designation.
3587+
"""
3588+
def __init__(self, elements: Iterable[Element], opts: ChunkingOptions):
3589+
self._elements = elements
3590+
self._opts = opts
3591+
@lazyproperty
3592+
def _boundary_predicates(self) -> tuple[BoundaryPredicate, ...]:
3593+
"""The semantic-boundary detectors to be applied to break pre-chunks."""
3594+
return self._opts.boundary_predicates
3595+
def _is_in_new_semantic_unit(self, element: Element) -> bool:
3596+
"""True when `element` begins a new semantic unit such as a section or page."""
3597+
# Use generator expression for lower memory usage and avoid building intermediate list
3598+
for pred in self._boundary_predicates:
3599+
if pred(element):
3600+
return True
3601+
return False
3602+
'''
3603+
assert new_code == expected

0 commit comments

Comments
 (0)