Skip to content

Commit 18df53f

Browse files
authored
Merge branch 'main' into temporal-python--updated
2 parents 1a64c24 + c4a19d7 commit 18df53f

File tree

8 files changed: +63 -49 lines changed

8 files changed: +63 -49 lines changed

codeflash/code_utils/coverage_utils.py

Lines changed: 6 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -12,11 +12,12 @@
1212

1313
def extract_dependent_function(main_function: str, code_context: CodeOptimizationContext) -> str | Literal[False]:
1414
"""Extract the single dependent function from the code context excluding the main function."""
15-
ast_tree = ast.parse(code_context.testgen_context_code)
16-
17-
dependent_functions = {
18-
node.name for node in ast_tree.body if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef))
19-
}
15+
dependent_functions = set()
16+
for code_string in code_context.testgen_context.code_strings:
17+
ast_tree = ast.parse(code_string.code)
18+
dependent_functions.update(
19+
{node.name for node in ast_tree.body if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef))}
20+
)
2021

2122
if main_function in dependent_functions:
2223
dependent_functions.discard(main_function)

codeflash/context/code_context_extractor.py

Lines changed: 9 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -114,32 +114,32 @@ def get_code_optimization_context(
114114
read_only_context_code = ""
115115

116116
# Extract code context for testgen
117-
testgen_code_markdown = extract_code_string_context_from_files(
117+
testgen_context = extract_code_markdown_context_from_files(
118118
helpers_of_fto_dict,
119119
helpers_of_helpers_dict,
120120
project_root_path,
121121
remove_docstrings=False,
122122
code_context_type=CodeContextType.TESTGEN,
123123
)
124-
testgen_context_code = testgen_code_markdown.code
125-
testgen_context_code_tokens = encoded_tokens_len(testgen_context_code)
126-
if testgen_context_code_tokens > testgen_token_limit:
127-
testgen_code_markdown = extract_code_string_context_from_files(
124+
testgen_markdown_code = testgen_context.markdown
125+
testgen_code_token_length = encoded_tokens_len(testgen_markdown_code)
126+
if testgen_code_token_length > testgen_token_limit:
127+
testgen_context = extract_code_markdown_context_from_files(
128128
helpers_of_fto_dict,
129129
helpers_of_helpers_dict,
130130
project_root_path,
131131
remove_docstrings=True,
132132
code_context_type=CodeContextType.TESTGEN,
133133
)
134-
testgen_context_code = testgen_code_markdown.code
135-
testgen_context_code_tokens = encoded_tokens_len(testgen_context_code)
136-
if testgen_context_code_tokens > testgen_token_limit:
134+
testgen_markdown_code = testgen_context.markdown
135+
testgen_code_token_length = encoded_tokens_len(testgen_markdown_code)
136+
if testgen_code_token_length > testgen_token_limit:
137137
raise ValueError("Testgen code context has exceeded token limit, cannot proceed")
138138
code_hash_context = hashing_code_context.markdown
139139
code_hash = hashlib.sha256(code_hash_context.encode("utf-8")).hexdigest()
140140

141141
return CodeOptimizationContext(
142-
testgen_context_code=testgen_context_code,
142+
testgen_context=testgen_context,
143143
read_writable_code=final_read_writable_code,
144144
read_only_context_code=read_only_context_code,
145145
hashing_code_context=code_hash_context,

codeflash/models/models.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -163,7 +163,7 @@ class CodeString(BaseModel):
163163

164164

165165
def get_code_block_splitter(file_path: Path) -> str:
166-
return f"# file: {file_path}"
166+
return f"# file: {file_path.as_posix()}"
167167

168168

169169
markdown_pattern = re.compile(r"```python:([^\n]+)\n(.*?)\n```", re.DOTALL)
@@ -254,7 +254,7 @@ def parse_markdown_code(markdown_code: str) -> CodeStringsMarkdown:
254254

255255

256256
class CodeOptimizationContext(BaseModel):
257-
testgen_context_code: str = ""
257+
testgen_context: CodeStringsMarkdown
258258
read_writable_code: CodeStringsMarkdown
259259
read_only_context_code: str = ""
260260
hashing_code_context: str = ""

codeflash/optimization/function_optimizer.py

Lines changed: 4 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -309,7 +309,7 @@ def generate_and_instrument_tests(
309309
revert_to_print=bool(get_pr_number()),
310310
):
311311
generated_results = self.generate_tests_and_optimizations(
312-
testgen_context_code=code_context.testgen_context_code,
312+
testgen_context=code_context.testgen_context,
313313
read_writable_code=code_context.read_writable_code,
314314
read_only_context_code=code_context.read_only_context_code,
315315
helper_functions=code_context.helper_functions,
@@ -345,7 +345,6 @@ def generate_and_instrument_tests(
345345
logger.info(f"Generated test {i + 1}/{count_tests}:")
346346
code_print(generated_test.generated_original_test_source, file_name=f"test_{i + 1}.py")
347347
if concolic_test_str:
348-
# no concolic tests in lsp mode
349348
logger.info(f"Generated test {count_tests}/{count_tests}:")
350349
code_print(concolic_test_str)
351350

@@ -972,7 +971,7 @@ def get_code_optimization_context(self) -> Result[CodeOptimizationContext, str]:
972971

973972
return Success(
974973
CodeOptimizationContext(
975-
testgen_context_code=new_code_ctx.testgen_context_code,
974+
testgen_context=new_code_ctx.testgen_context,
976975
read_writable_code=new_code_ctx.read_writable_code,
977976
read_only_context_code=new_code_ctx.read_only_context_code,
978977
hashing_code_context=new_code_ctx.hashing_code_context,
@@ -1079,7 +1078,7 @@ def instrument_existing_tests(self, function_to_all_tests: dict[str, set[Functio
10791078

10801079
def generate_tests_and_optimizations(
10811080
self,
1082-
testgen_context_code: str,
1081+
testgen_context: CodeStringsMarkdown,
10831082
read_writable_code: CodeStringsMarkdown,
10841083
read_only_context_code: str,
10851084
helper_functions: list[FunctionSource],
@@ -1093,7 +1092,7 @@ def generate_tests_and_optimizations(
10931092
# Submit the test generation task as future
10941093
future_tests = self.submit_test_generation_tasks(
10951094
self.executor,
1096-
testgen_context_code,
1095+
testgen_context.markdown,
10971096
[definition.fully_qualified_name for definition in helper_functions],
10981097
generated_test_paths,
10991098
generated_perf_test_paths,

tests/test_code_replacement.py

Lines changed: 3 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -798,7 +798,8 @@ def main_method(self):
798798

799799

800800
def test_code_replacement10() -> None:
801-
get_code_output = """from __future__ import annotations
801+
get_code_output = """# file: test_code_replacement.py
802+
from __future__ import annotations
802803
803804
class HelperClass:
804805
def __init__(self, name):
@@ -828,7 +829,7 @@ def main_method(self):
828829
)
829830
func_optimizer = FunctionOptimizer(function_to_optimize=func_top_optimize, test_cfg=test_config)
830831
code_context = func_optimizer.get_code_optimization_context().unwrap()
831-
assert code_context.testgen_context_code.rstrip() == get_code_output.rstrip()
832+
assert code_context.testgen_context.flat.rstrip() == get_code_output.rstrip()
832833

833834

834835
def test_code_replacement11() -> None:

tests/test_code_utils.py

Lines changed: 21 additions & 13 deletions
Original file line number | Diff line number | Diff line change
@@ -21,6 +21,7 @@
2121
)
2222
from codeflash.code_utils.concolic_utils import clean_concolic_tests
2323
from codeflash.code_utils.coverage_utils import extract_dependent_function, generate_candidates, prepare_coverage_files
24+
from codeflash.models.models import CodeStringsMarkdown
2425

2526

2627
@pytest.fixture
@@ -382,69 +383,76 @@ def mock_code_context():
382383
def test_extract_dependent_function_sync_and_async(mock_code_context):
383384
"""Test extract_dependent_function with both sync and async functions."""
384385
# Test sync function extraction
385-
mock_code_context.testgen_context_code = """
386+
mock_code_context.testgen_context = CodeStringsMarkdown.parse_markdown_code("""```python:file.py
386387
def main_function():
387388
pass
388389
389390
def helper_function():
390391
pass
391-
"""
392+
```
393+
""")
392394
assert extract_dependent_function("main_function", mock_code_context) == "helper_function"
393395

394396
# Test async function extraction
395-
mock_code_context.testgen_context_code = """
397+
mock_code_context.testgen_context = CodeStringsMarkdown.parse_markdown_code("""```python:file.py
396398
def main_function():
397399
pass
398400
399401
async def async_helper_function():
400402
pass
401-
"""
403+
```
404+
""")
405+
402406
assert extract_dependent_function("main_function", mock_code_context) == "async_helper_function"
403407

404408

405409
def test_extract_dependent_function_edge_cases(mock_code_context):
406410
"""Test extract_dependent_function edge cases."""
407411
# No dependent functions
408-
mock_code_context.testgen_context_code = """
412+
mock_code_context.testgen_context = CodeStringsMarkdown.parse_markdown_code("""```python:file.py
409413
def main_function():
410414
pass
411-
"""
415+
```
416+
""")
412417
assert extract_dependent_function("main_function", mock_code_context) is False
413418

414419
# Multiple dependent functions
415-
mock_code_context.testgen_context_code = """
420+
mock_code_context.testgen_context = CodeStringsMarkdown.parse_markdown_code("""```python:file.py
416421
def main_function():
417422
pass
418-
419423
def helper1():
420424
pass
421425
422426
async def helper2():
423427
pass
424-
"""
428+
```
429+
""")
425430
assert extract_dependent_function("main_function", mock_code_context) is False
426431

427432

428433
def test_extract_dependent_function_mixed_scenarios(mock_code_context):
429434
"""Test extract_dependent_function with mixed sync/async scenarios."""
430435
# Async main with sync helper
431-
mock_code_context.testgen_context_code = """
436+
mock_code_context.testgen_context = CodeStringsMarkdown.parse_markdown_code("""```python:file.py
432437
async def async_main():
433438
pass
434439
435440
def sync_helper():
436441
pass
437-
"""
442+
```
443+
""")
438444
assert extract_dependent_function("async_main", mock_code_context) == "sync_helper"
439445

440446
# Only async functions
441-
mock_code_context.testgen_context_code = """
447+
mock_code_context.testgen_context = CodeStringsMarkdown.parse_markdown_code("""```python:file.py
442448
async def async_main():
443449
pass
444450
445451
async def async_helper():
446452
pass
447-
"""
453+
```
454+
""")
455+
448456
assert extract_dependent_function("async_main", mock_code_context) == "async_helper"
449457

450458

tests/test_function_dependencies.py

Lines changed: 6 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -160,8 +160,9 @@ def test_class_method_dependencies() -> None:
160160
)
161161
assert code_context.helper_functions[0].qualified_name == "Graph.topologicalSortUtil"
162162
assert (
163-
code_context.testgen_context_code
164-
== """from collections import defaultdict
163+
code_context.testgen_context.flat
164+
== """# file: test_function_dependencies.py
165+
from collections import defaultdict
165166
166167
class Graph:
167168
def __init__(self, vertices):
@@ -220,8 +221,9 @@ def test_recursive_function_context() -> None:
220221
assert code_context.helper_functions[0].fully_qualified_name == "test_function_dependencies.C.calculate_something_3"
221222
assert code_context.helper_functions[1].fully_qualified_name == "test_function_dependencies.C.recursive"
222223
assert (
223-
code_context.testgen_context_code
224-
== """class C:
224+
code_context.testgen_context.flat
225+
== """# file: test_function_dependencies.py
226+
class C:
225227
def calculate_something_3(self, num):
226228
return num + 1
227229

tests/test_get_helper_code.py

Lines changed: 12 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -6,7 +6,7 @@
66

77
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
88
from codeflash.either import is_successful
9-
from codeflash.models.models import FunctionParent
9+
from codeflash.models.models import FunctionParent, get_code_block_splitter
1010
from codeflash.optimization.function_optimizer import FunctionOptimizer
1111
from codeflash.optimization.optimizer import Optimizer
1212
from codeflash.verification.verification_utils import TestConfig
@@ -242,8 +242,9 @@ def __call__(self, *args: _P.args, **kwargs: _P.kwargs) -> _R:
242242
code_context = ctx_result.unwrap()
243243
assert code_context.helper_functions[0].qualified_name == "AbstractCacheBackend.get_cache_or_call"
244244
assert (
245-
code_context.testgen_context_code
246-
== f'''_P = ParamSpec("_P")
245+
code_context.testgen_context.flat
246+
== f'''# file: {file_path.relative_to(project_root_path)}
247+
_P = ParamSpec("_P")
247248
_KEY_T = TypeVar("_KEY_T")
248249
_STORE_T = TypeVar("_STORE_T")
249250
class AbstractCacheBackend(CacheBackend, Protocol[_KEY_T, _STORE_T]):
@@ -395,10 +396,11 @@ def test_bubble_sort_deps() -> None:
395396
function_to_optimize = FunctionToOptimize(
396397
function_name="sorter_deps", file_path=file_path, parents=[], starting_line=None, ending_line=None
397398
)
399+
project_root = file_path.parent.parent.resolve()
398400
test_config = TestConfig(
399401
tests_root=str(file_path.parent / "tests"),
400402
tests_project_rootdir=file_path.parent.resolve(),
401-
project_root_path=file_path.parent.parent.resolve(),
403+
project_root_path=project_root,
402404
test_framework="pytest",
403405
pytest_cmd="pytest",
404406
)
@@ -410,19 +412,20 @@ def test_bubble_sort_deps() -> None:
410412
pytest.fail()
411413
code_context = ctx_result.unwrap()
412414
assert (
413-
code_context.testgen_context_code
414-
== """from code_to_optimize.bubble_sort_dep1_helper import dep1_comparer
415-
from code_to_optimize.bubble_sort_dep2_swap import dep2_swap
416-
415+
code_context.testgen_context.flat
416+
== f"""{get_code_block_splitter(Path("code_to_optimize/bubble_sort_dep1_helper.py"))}
417417
def dep1_comparer(arr, j: int) -> bool:
418418
return arr[j] > arr[j + 1]
419419
420+
{get_code_block_splitter(Path("code_to_optimize/bubble_sort_dep2_swap.py"))}
420421
def dep2_swap(arr, j):
421422
temp = arr[j]
422423
arr[j] = arr[j + 1]
423424
arr[j + 1] = temp
424425
425-
426+
{get_code_block_splitter(Path("code_to_optimize/bubble_sort_deps.py"))}
427+
from code_to_optimize.bubble_sort_dep1_helper import dep1_comparer
428+
from code_to_optimize.bubble_sort_dep2_swap import dep2_swap
426429
427430
def sorter_deps(arr):
428431
for i in range(len(arr)):

0 commit comments

Comments
 (0)