Skip to content

Commit a9cebf1

Browse files
authored
Merge branch 'main' into worktree/mirror-all-arg
2 parents a3907d8 + 63466b1 commit a9cebf1

File tree

9 files changed

+91
-70
lines changed

9 files changed

+91
-70
lines changed

codeflash/code_utils/coverage_utils.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -12,11 +12,12 @@
1212

1313
def extract_dependent_function(main_function: str, code_context: CodeOptimizationContext) -> str | Literal[False]:
1414
"""Extract the single dependent function from the code context excluding the main function."""
15-
ast_tree = ast.parse(code_context.testgen_context_code)
16-
17-
dependent_functions = {
18-
node.name for node in ast_tree.body if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef))
19-
}
15+
dependent_functions = set()
16+
for code_string in code_context.testgen_context.code_strings:
17+
ast_tree = ast.parse(code_string.code)
18+
dependent_functions.update(
19+
{node.name for node in ast_tree.body if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef))}
20+
)
2021

2122
if main_function in dependent_functions:
2223
dependent_functions.discard(main_function)

codeflash/context/code_context_extractor.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -114,32 +114,32 @@ def get_code_optimization_context(
114114
read_only_context_code = ""
115115

116116
# Extract code context for testgen
117-
testgen_code_markdown = extract_code_string_context_from_files(
117+
testgen_context = extract_code_markdown_context_from_files(
118118
helpers_of_fto_dict,
119119
helpers_of_helpers_dict,
120120
project_root_path,
121121
remove_docstrings=False,
122122
code_context_type=CodeContextType.TESTGEN,
123123
)
124-
testgen_context_code = testgen_code_markdown.code
125-
testgen_context_code_tokens = encoded_tokens_len(testgen_context_code)
126-
if testgen_context_code_tokens > testgen_token_limit:
127-
testgen_code_markdown = extract_code_string_context_from_files(
124+
testgen_markdown_code = testgen_context.markdown
125+
testgen_code_token_length = encoded_tokens_len(testgen_markdown_code)
126+
if testgen_code_token_length > testgen_token_limit:
127+
testgen_context = extract_code_markdown_context_from_files(
128128
helpers_of_fto_dict,
129129
helpers_of_helpers_dict,
130130
project_root_path,
131131
remove_docstrings=True,
132132
code_context_type=CodeContextType.TESTGEN,
133133
)
134-
testgen_context_code = testgen_code_markdown.code
135-
testgen_context_code_tokens = encoded_tokens_len(testgen_context_code)
136-
if testgen_context_code_tokens > testgen_token_limit:
134+
testgen_markdown_code = testgen_context.markdown
135+
testgen_code_token_length = encoded_tokens_len(testgen_markdown_code)
136+
if testgen_code_token_length > testgen_token_limit:
137137
raise ValueError("Testgen code context has exceeded token limit, cannot proceed")
138138
code_hash_context = hashing_code_context.markdown
139139
code_hash = hashlib.sha256(code_hash_context.encode("utf-8")).hexdigest()
140140

141141
return CodeOptimizationContext(
142-
testgen_context_code=testgen_context_code,
142+
testgen_context=testgen_context,
143143
read_writable_code=final_read_writable_code,
144144
read_only_context_code=read_only_context_code,
145145
hashing_code_context=code_hash_context,
Lines changed: 28 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
11
# ruff: noqa
22
import sys
3+
from pathlib import Path
34
from typing import Any
5+
import pickle
6+
47

58
# This script should not have any relation to the codeflash package, be careful with imports
69
cwd = sys.argv[1]
@@ -11,44 +14,48 @@
1114
sys.path.insert(1, str(cwd))
1215

1316

17+
def parse_pytest_collection_results(pytest_tests: list[Any]) -> list[dict[str, str]]:
18+
test_results = []
19+
for test in pytest_tests:
20+
test_class = None
21+
if test.cls:
22+
test_class = test.parent.name
23+
test_results.append({"test_file": str(test.path), "test_class": test_class, "test_function": test.name})
24+
return test_results
25+
26+
1427
class PytestCollectionPlugin:
1528
def pytest_collection_finish(self, session) -> None:
16-
global pytest_rootdir
29+
global pytest_rootdir, collected_tests
30+
1731
collected_tests.extend(session.items)
1832
pytest_rootdir = session.config.rootdir
1933

34+
# Write results immediately since pytest.main() will exit after this callback, not always with a success code
35+
tests = parse_pytest_collection_results(collected_tests)
36+
exit_code = getattr(session.config, "exitstatus", 0)
37+
with Path(pickle_path).open("wb") as f:
38+
pickle.dump((exit_code, tests, pytest_rootdir), f, protocol=pickle.HIGHEST_PROTOCOL)
39+
2040
def pytest_collection_modifyitems(self, items) -> None:
2141
skip_benchmark = pytest.mark.skip(reason="Skipping benchmark tests")
2242
for item in items:
2343
if "benchmark" in item.fixturenames:
2444
item.add_marker(skip_benchmark)
2545

2646

27-
def parse_pytest_collection_results(pytest_tests: list[Any]) -> list[dict[str, str]]:
28-
test_results = []
29-
for test in pytest_tests:
30-
test_class = None
31-
if test.cls:
32-
test_class = test.parent.name
33-
test_results.append({"test_file": str(test.path), "test_class": test_class, "test_function": test.name})
34-
return test_results
35-
36-
3747
if __name__ == "__main__":
38-
from pathlib import Path
39-
4048
import pytest
4149

4250
try:
43-
exitcode = pytest.main(
44-
[tests_root, "-p no:logging", "--collect-only", "-m", "not skip", "-p", "no:codeflash-benchmark"],
51+
pytest.main(
52+
[tests_root, "-p", "no:logging", "--collect-only", "-m", "not skip", "-p", "no:codeflash-benchmark"],
4553
plugins=[PytestCollectionPlugin()],
4654
)
4755
except Exception as e:
4856
print(f"Failed to collect tests: {e!s}")
49-
exitcode = -1
50-
tests = parse_pytest_collection_results(collected_tests)
51-
import pickle
52-
53-
with Path(pickle_path).open("wb") as f:
54-
pickle.dump((exitcode, tests, pytest_rootdir), f, protocol=pickle.HIGHEST_PROTOCOL)
57+
try:
58+
with Path(pickle_path).open("wb") as f:
59+
pickle.dump((-1, [], None), f, protocol=pickle.HIGHEST_PROTOCOL)
60+
except Exception as pickle_error:
61+
print(f"Failed to write failure pickle: {pickle_error!s}", file=sys.stderr)

codeflash/models/models.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -163,7 +163,7 @@ class CodeString(BaseModel):
163163

164164

165165
def get_code_block_splitter(file_path: Path) -> str:
166-
return f"# file: {file_path}"
166+
return f"# file: {file_path.as_posix()}"
167167

168168

169169
markdown_pattern = re.compile(r"```python:([^\n]+)\n(.*?)\n```", re.DOTALL)
@@ -254,7 +254,7 @@ def parse_markdown_code(markdown_code: str) -> CodeStringsMarkdown:
254254

255255

256256
class CodeOptimizationContext(BaseModel):
257-
testgen_context_code: str = ""
257+
testgen_context: CodeStringsMarkdown
258258
read_writable_code: CodeStringsMarkdown
259259
read_only_context_code: str = ""
260260
hashing_code_context: str = ""

codeflash/optimization/function_optimizer.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -309,7 +309,7 @@ def generate_and_instrument_tests(
309309
revert_to_print=bool(get_pr_number()),
310310
):
311311
generated_results = self.generate_tests_and_optimizations(
312-
testgen_context_code=code_context.testgen_context_code,
312+
testgen_context=code_context.testgen_context,
313313
read_writable_code=code_context.read_writable_code,
314314
read_only_context_code=code_context.read_only_context_code,
315315
helper_functions=code_context.helper_functions,
@@ -345,7 +345,6 @@ def generate_and_instrument_tests(
345345
logger.info(f"Generated test {i + 1}/{count_tests}:")
346346
code_print(generated_test.generated_original_test_source, file_name=f"test_{i + 1}.py")
347347
if concolic_test_str:
348-
# no concolic tests in lsp mode
349348
logger.info(f"Generated test {count_tests}/{count_tests}:")
350349
code_print(concolic_test_str)
351350

@@ -972,7 +971,7 @@ def get_code_optimization_context(self) -> Result[CodeOptimizationContext, str]:
972971

973972
return Success(
974973
CodeOptimizationContext(
975-
testgen_context_code=new_code_ctx.testgen_context_code,
974+
testgen_context=new_code_ctx.testgen_context,
976975
read_writable_code=new_code_ctx.read_writable_code,
977976
read_only_context_code=new_code_ctx.read_only_context_code,
978977
hashing_code_context=new_code_ctx.hashing_code_context,
@@ -1079,7 +1078,7 @@ def instrument_existing_tests(self, function_to_all_tests: dict[str, set[Functio
10791078

10801079
def generate_tests_and_optimizations(
10811080
self,
1082-
testgen_context_code: str,
1081+
testgen_context: CodeStringsMarkdown,
10831082
read_writable_code: CodeStringsMarkdown,
10841083
read_only_context_code: str,
10851084
helper_functions: list[FunctionSource],
@@ -1093,7 +1092,7 @@ def generate_tests_and_optimizations(
10931092
# Submit the test generation task as future
10941093
future_tests = self.submit_test_generation_tasks(
10951094
self.executor,
1096-
testgen_context_code,
1095+
testgen_context.markdown,
10971096
[definition.fully_qualified_name for definition in helper_functions],
10981097
generated_test_paths,
10991098
generated_perf_test_paths,

tests/test_code_replacement.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -798,7 +798,8 @@ def main_method(self):
798798

799799

800800
def test_code_replacement10() -> None:
801-
get_code_output = """from __future__ import annotations
801+
get_code_output = """# file: test_code_replacement.py
802+
from __future__ import annotations
802803
803804
class HelperClass:
804805
def __init__(self, name):
@@ -828,7 +829,7 @@ def main_method(self):
828829
)
829830
func_optimizer = FunctionOptimizer(function_to_optimize=func_top_optimize, test_cfg=test_config)
830831
code_context = func_optimizer.get_code_optimization_context().unwrap()
831-
assert code_context.testgen_context_code.rstrip() == get_code_output.rstrip()
832+
assert code_context.testgen_context.flat.rstrip() == get_code_output.rstrip()
832833

833834

834835
def test_code_replacement11() -> None:

tests/test_code_utils.py

Lines changed: 21 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
)
2222
from codeflash.code_utils.concolic_utils import clean_concolic_tests
2323
from codeflash.code_utils.coverage_utils import extract_dependent_function, generate_candidates, prepare_coverage_files
24+
from codeflash.models.models import CodeStringsMarkdown
2425

2526

2627
@pytest.fixture
@@ -382,69 +383,76 @@ def mock_code_context():
382383
def test_extract_dependent_function_sync_and_async(mock_code_context):
383384
"""Test extract_dependent_function with both sync and async functions."""
384385
# Test sync function extraction
385-
mock_code_context.testgen_context_code = """
386+
mock_code_context.testgen_context = CodeStringsMarkdown.parse_markdown_code("""```python:file.py
386387
def main_function():
387388
pass
388389
389390
def helper_function():
390391
pass
391-
"""
392+
```
393+
""")
392394
assert extract_dependent_function("main_function", mock_code_context) == "helper_function"
393395

394396
# Test async function extraction
395-
mock_code_context.testgen_context_code = """
397+
mock_code_context.testgen_context = CodeStringsMarkdown.parse_markdown_code("""```python:file.py
396398
def main_function():
397399
pass
398400
399401
async def async_helper_function():
400402
pass
401-
"""
403+
```
404+
""")
405+
402406
assert extract_dependent_function("main_function", mock_code_context) == "async_helper_function"
403407

404408

405409
def test_extract_dependent_function_edge_cases(mock_code_context):
406410
"""Test extract_dependent_function edge cases."""
407411
# No dependent functions
408-
mock_code_context.testgen_context_code = """
412+
mock_code_context.testgen_context = CodeStringsMarkdown.parse_markdown_code("""```python:file.py
409413
def main_function():
410414
pass
411-
"""
415+
```
416+
""")
412417
assert extract_dependent_function("main_function", mock_code_context) is False
413418

414419
# Multiple dependent functions
415-
mock_code_context.testgen_context_code = """
420+
mock_code_context.testgen_context = CodeStringsMarkdown.parse_markdown_code("""```python:file.py
416421
def main_function():
417422
pass
418-
419423
def helper1():
420424
pass
421425
422426
async def helper2():
423427
pass
424-
"""
428+
```
429+
""")
425430
assert extract_dependent_function("main_function", mock_code_context) is False
426431

427432

428433
def test_extract_dependent_function_mixed_scenarios(mock_code_context):
429434
"""Test extract_dependent_function with mixed sync/async scenarios."""
430435
# Async main with sync helper
431-
mock_code_context.testgen_context_code = """
436+
mock_code_context.testgen_context = CodeStringsMarkdown.parse_markdown_code("""```python:file.py
432437
async def async_main():
433438
pass
434439
435440
def sync_helper():
436441
pass
437-
"""
442+
```
443+
""")
438444
assert extract_dependent_function("async_main", mock_code_context) == "sync_helper"
439445

440446
# Only async functions
441-
mock_code_context.testgen_context_code = """
447+
mock_code_context.testgen_context = CodeStringsMarkdown.parse_markdown_code("""```python:file.py
442448
async def async_main():
443449
pass
444450
445451
async def async_helper():
446452
pass
447-
"""
453+
```
454+
""")
455+
448456
assert extract_dependent_function("async_main", mock_code_context) == "async_helper"
449457

450458

tests/test_function_dependencies.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -160,8 +160,9 @@ def test_class_method_dependencies() -> None:
160160
)
161161
assert code_context.helper_functions[0].qualified_name == "Graph.topologicalSortUtil"
162162
assert (
163-
code_context.testgen_context_code
164-
== """from collections import defaultdict
163+
code_context.testgen_context.flat
164+
== """# file: test_function_dependencies.py
165+
from collections import defaultdict
165166
166167
class Graph:
167168
def __init__(self, vertices):
@@ -220,8 +221,9 @@ def test_recursive_function_context() -> None:
220221
assert code_context.helper_functions[0].fully_qualified_name == "test_function_dependencies.C.calculate_something_3"
221222
assert code_context.helper_functions[1].fully_qualified_name == "test_function_dependencies.C.recursive"
222223
assert (
223-
code_context.testgen_context_code
224-
== """class C:
224+
code_context.testgen_context.flat
225+
== """# file: test_function_dependencies.py
226+
class C:
225227
def calculate_something_3(self, num):
226228
return num + 1
227229

0 commit comments

Comments
 (0)