Skip to content

Commit 99f0954

Browse files
committed
lazy impl
1 parent 7ee1ab1 commit 99f0954

File tree

3 files changed

+156
-4
lines changed

3 files changed

+156
-4
lines changed

codeflash/discovery/discover_unit_tests.py

Lines changed: 149 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
import re
1010
import sqlite3
1111
import subprocess
12+
import sys
1213
import unittest
1314
from collections import defaultdict
1415
from pathlib import Path
@@ -66,6 +67,75 @@ class TestFunction:
6667
FUNCTION_NAME_REGEX = re.compile(r"([^.]+)\.([a-zA-Z0-9_]+)$")
6768

6869

70+
def _extract_dotted_call_name(node: ast.expr) -> str | None:
71+
"""Extract full dotted name from function call (e.g., 'src.math.computation.gcd_recursive')."""
72+
parts = []
73+
current = node
74+
while isinstance(current, ast.Attribute):
75+
parts.insert(0, current.attr)
76+
current = current.value
77+
if isinstance(current, ast.Name):
78+
parts.insert(0, current.id)
79+
return ".".join(parts) if parts else None
80+
return None
81+
82+
83+
def _discover_calls_via_ast(
84+
test_file: Path, test_functions: set[TestFunction], target_qualified_names: set[str]
85+
) -> dict[str, list[tuple[TestFunction, CodePosition]]]:
86+
try:
87+
with test_file.open("r", encoding="utf-8") as f:
88+
source = f.read()
89+
tree = ast.parse(source, filename=str(test_file))
90+
except (SyntaxError, FileNotFoundError) as e:
91+
logger.debug(f"AST parsing failed for {test_file}: {e}")
92+
return {}
93+
94+
import_map = {} # alias -> full_qualified_path
95+
for node in ast.walk(tree):
96+
if isinstance(node, ast.Import):
97+
for alias in node.names:
98+
name = alias.asname or alias.name
99+
import_map[name] = alias.name
100+
elif isinstance(node, ast.ImportFrom) and node.module:
101+
for alias in node.names:
102+
if alias.name != "*":
103+
full_name = f"{node.module}.{alias.name}"
104+
name = alias.asname or alias.name
105+
import_map[name] = full_name
106+
107+
test_funcs_by_name = {tf.function_name: tf for tf in test_functions}
108+
109+
result = defaultdict(list)
110+
111+
for node in ast.walk(tree):
112+
if not isinstance(node, ast.FunctionDef) or node.name not in test_funcs_by_name:
113+
continue
114+
115+
test_func = test_funcs_by_name[node.name]
116+
117+
for child in ast.walk(node):
118+
if not isinstance(child, ast.Call):
119+
continue
120+
121+
call_name = _extract_dotted_call_name(child.func)
122+
if not call_name:
123+
continue
124+
125+
if call_name in target_qualified_names:
126+
result[call_name].append((test_func, CodePosition(line_no=child.lineno, col_no=child.col_offset)))
127+
continue
128+
129+
parts = call_name.split(".", 1)
130+
if parts[0] in import_map:
131+
resolved = f"{import_map[parts[0]]}.{parts[1]}" if len(parts) == 2 else import_map[parts[0]]
132+
133+
if resolved in target_qualified_names:
134+
result[resolved].append((test_func, CodePosition(line_no=child.lineno, col_no=child.col_offset)))
135+
136+
return dict(result)
137+
138+
69139
class TestsCache:
70140
SCHEMA_VERSION = 1 # Increment this when schema changes
71141

@@ -489,6 +559,7 @@ def discover_tests_pytest(
489559
console.rule()
490560
else:
491561
logger.debug(f"Pytest collection exit code: {exitcode}")
562+
492563
if pytest_rootdir is not None:
493564
cfg.tests_project_rootdir = Path(pytest_rootdir)
494565
file_to_test_map: dict[Path, list[FunctionCalledInTest]] = defaultdict(list)
@@ -511,6 +582,7 @@ def discover_tests_pytest(
511582
if discover_only_these_tests and test_obj.test_file not in discover_only_these_tests:
512583
continue
513584
file_to_test_map[test_obj.test_file].append(test_obj)
585+
514586
# Within these test files, find the project functions they are referring to and return their names/locations
515587
return process_test_files(file_to_test_map, cfg, functions_to_optimize)
516588

@@ -592,7 +664,9 @@ def process_test_files(
592664
test_framework = cfg.test_framework
593665

594666
if functions_to_optimize:
595-
target_function_names = {func.qualified_name for func in functions_to_optimize}
667+
target_function_names = {
668+
func.qualified_name_with_modules_from_root(project_root_path) for func in functions_to_optimize
669+
}
596670
file_to_test_map = filter_test_files_by_imports(file_to_test_map, target_function_names)
597671

598672
function_to_test_map = defaultdict(set)
@@ -602,6 +676,7 @@ def process_test_files(
602676

603677
tests_cache = TestsCache(project_root_path)
604678
logger.info("!lsp|Discovering tests and processing unit tests")
679+
605680
with test_files_progress_bar(total=len(file_to_test_map), description="Processing test files") as (
606681
progress,
607682
task_id,
@@ -702,6 +777,79 @@ def process_test_files(
702777
test_functions_by_name[func.function_name].append(func)
703778

704779
test_function_names_set = set(test_functions_by_name.keys())
780+
781+
is_generated_test_file = (
782+
any(
783+
tf.test_type in (TestType.HYPOTHESIS_TEST, TestType.CONCOLIC_COVERAGE_TEST) for tf in test_functions
784+
)
785+
if test_functions
786+
else any(
787+
func.test_type in (TestType.HYPOTHESIS_TEST, TestType.CONCOLIC_COVERAGE_TEST) for func in functions
788+
)
789+
)
790+
791+
# For generated tests, use AST-based discovery since Jedi often fails
792+
if is_generated_test_file and functions_to_optimize:
793+
logger.debug(f"Using AST-based discovery for generated test file: {test_file.name}")
794+
target_qualified_names = {
795+
func.qualified_name_with_modules_from_root(project_root_path) for func in functions_to_optimize
796+
}
797+
798+
if not test_functions:
799+
logger.debug("Jedi found no functions, building test_functions from collected functions")
800+
test_functions = {
801+
TestFunction(
802+
function_name=func.test_function,
803+
test_class=func.test_class,
804+
parameters=None,
805+
test_type=func.test_type,
806+
)
807+
for func in functions
808+
}
809+
810+
ast_results = _discover_calls_via_ast(test_file, test_functions, target_qualified_names)
811+
812+
for qualified_name, matches in ast_results.items():
813+
for test_func, position in matches:
814+
if test_func.parameters is not None:
815+
if test_framework == "pytest":
816+
scope_test_function = f"{test_func.function_name}[{test_func.parameters}]"
817+
else: # unittest
818+
scope_test_function = f"{test_func.function_name}_{test_func.parameters}"
819+
else:
820+
scope_test_function = test_func.function_name
821+
822+
function_to_test_map[qualified_name].add(
823+
FunctionCalledInTest(
824+
tests_in_file=TestsInFile(
825+
test_file=test_file,
826+
test_class=test_func.test_class,
827+
test_function=scope_test_function,
828+
test_type=test_func.test_type,
829+
),
830+
position=position,
831+
)
832+
)
833+
tests_cache.insert_test(
834+
file_path=str(test_file),
835+
file_hash=file_hash,
836+
qualified_name_with_modules_from_root=qualified_name,
837+
function_name=test_func.function_name,
838+
test_class=test_func.test_class or "",
839+
test_function=scope_test_function,
840+
test_type=test_func.test_type,
841+
line_number=position.line_no,
842+
col_number=position.col_no,
843+
)
844+
845+
if test_func.test_type == TestType.REPLAY_TEST:
846+
num_discovered_replay_tests += 1
847+
848+
num_discovered_tests += 1
849+
850+
progress.advance(task_id)
851+
continue
852+
705853
relevant_names = []
706854

707855
names_with_full_name = [name for name in all_names if name.full_name is not None]

codeflash/verification/concolic_testing.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,10 @@ def generate_concolic_tests(
8080
test_framework=args.test_framework,
8181
pytest_cmd=args.pytest_cmd,
8282
)
83-
function_to_concolic_tests, num_discovered_concolic_tests, _ = discover_unit_tests(concolic_test_cfg)
83+
file_to_funcs = {function_to_optimize.file_path: [function_to_optimize]}
84+
function_to_concolic_tests, num_discovered_concolic_tests, _ = discover_unit_tests(
85+
concolic_test_cfg, file_to_funcs_to_optimize=file_to_funcs
86+
)
8487
logger.info(
8588
f"Created {num_discovered_concolic_tests} "
8689
f"concolic unit test case{'s' if num_discovered_concolic_tests != 1 else ''} "

codeflash/verification/hypothesis_testing.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -268,8 +268,9 @@ def generate_hypothesis_tests(
268268
test_framework=args.test_framework,
269269
pytest_cmd=args.pytest_cmd,
270270
)
271+
file_to_funcs = {function_to_optimize.file_path: [function_to_optimize]}
271272
function_to_hypothesis_tests, num_discovered_hypothesis_tests, _ = (
272-
discover_unit_tests(hypothesis_config)
273+
discover_unit_tests(hypothesis_config, file_to_funcs_to_optimize=file_to_funcs)
273274
)
274275
with hypothesis_path.open("r", encoding="utf-8") as f:
275276
original_code = f.read()
@@ -290,7 +291,7 @@ def generate_hypothesis_tests(
290291
with hypothesis_path.open("w", encoding="utf-8") as f:
291292
f.write(hypothesis_test_suite_code)
292293
function_to_hypothesis_tests, num_discovered_hypothesis_tests, _ = (
293-
discover_unit_tests(hypothesis_config)
294+
discover_unit_tests(hypothesis_config, file_to_funcs_to_optimize=file_to_funcs)
294295
)
295296
logger.info(
296297
f"Created {num_discovered_hypothesis_tests} "

0 commit comments

Comments
 (0)