Skip to content

Commit 5830a70

Browse files
quick and dirty
1 parent cff6a30 commit 5830a70

File tree

6 files changed

+416
-40
lines changed

6 files changed

+416
-40
lines changed

codeflash/discovery/functions_to_optimize.py

Lines changed: 30 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -306,25 +306,43 @@ def levenshtein_distance(s1: str, s2: str) -> int:
306306
len1 = len(s1)
307307
len2 = len(s2)
308308
# Use a preallocated list instead of creating a new list every iteration
309+
310+
# Early exit for empty string cases
311+
if len1 == 0:
312+
return len2
313+
if len2 == 0:
314+
return len1
315+
316+
# Convert strings to lists for fast indexed access
317+
s1_list = list(s1)
318+
s2_list = list(s2)
319+
320+
# Preallocate and reuse arrays; avoid creating new ones every iteration
309321
previous = list(range(len1 + 1))
310322
current = [0] * (len1 + 1)
311323

312324
for index2 in range(len2):
313-
char2 = s2[index2]
325+
char2 = s2_list[index2]
314326
current[0] = index2 + 1
327+
328+
# Alias the row buffers and character list to loop-local names (re-bound each pass because the rows are swapped below)
329+
prev = previous
330+
curr = current
331+
s1_chars = s1_list
332+
# Use local variables for frequently accessed values
315333
for index1 in range(len1):
316-
char1 = s1[index1]
317-
if char1 == char2:
318-
current[index1 + 1] = previous[index1]
334+
# Compare the characters inline instead of first storing s1's character in a temporary
335+
if s1_chars[index1] == char2:
336+
curr[index1 + 1] = prev[index1]
319337
else:
320-
# Fast min calculation without tuple construct
321-
a = previous[index1]
322-
b = previous[index1 + 1]
323-
c = current[index1]
324-
min_val = min(b, a)
325-
min_val = min(c, min_val)
326-
current[index1 + 1] = 1 + min_val
327-
# Swap references instead of copying
338+
x = prev[index1]
339+
y = prev[index1 + 1]
340+
z = curr[index1]
341+
min_xy = min(x, y)
342+
min_xyz = min(z, min_xy)
343+
curr[index1 + 1] = 1 + min_xyz
344+
345+
# Swap references rather than copying data
328346
previous, current = current, previous
329347
return previous[len1]
330348

codeflash/models/models.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from collections import Counter, defaultdict
44
from typing import TYPE_CHECKING
55

6+
import libcst as cst
67
from rich.tree import Tree
78

89
from codeflash.cli_cmds.console import DEBUG_MODE, lsp_log
@@ -505,6 +506,31 @@ def id(self) -> str:
505506
f"{self.function_getting_tested}:{self.iteration_id}"
506507
)
507508

509+
def find_func_in_class(self, class_node: cst.ClassDef, func_name: str) -> Optional[cst.FunctionDef]:
    """Return the function named ``func_name`` defined directly in ``class_node``.

    Only the statements that sit directly in the class body are inspected;
    the first ``cst.FunctionDef`` whose name equals ``func_name`` wins.

    Args:
        class_node: The parsed class definition to search.
        func_name: Name of the function to look for.

    Returns:
        The matching function node, or ``None`` when the class body has no
        function with that name.
    """
    direct_matches = (
        member
        for member in class_node.body.body
        if isinstance(member, cst.FunctionDef) and member.name.value == func_name
    )
    return next(direct_matches, None)
514+
515+
def get_src_code(self, test_path: Path) -> Optional[str]:
    """Return the source text of the test function this id refers to.

    The file at ``test_path`` is read as UTF-8 and parsed with libcst. When
    ``self.test_class_name`` is set, the function is looked up inside every
    top-level class carrying that name; otherwise it is looked up among the
    module's top-level functions.

    Args:
        test_path: Path of the test file to read and parse.

    Returns:
        The stripped source code of the first matching function, or ``None``
        when no matching class/function is found.
    """
    module = cst.parse_module(test_path.read_text(encoding="utf-8"))

    if self.test_class_name:
        # Method lookup: walk each top-level class with the expected name and
        # return the first one that actually defines the test function.
        same_named_classes = (
            node
            for node in module.body
            if isinstance(node, cst.ClassDef) and node.name.value == self.test_class_name
        )
        for class_def in same_named_classes:
            method = self.find_func_in_class(class_def, self.test_function_name)
            if method:
                return module.code_for_node(method).strip()
        # Neither the class nor the method inside it was found.
        return None

    # No class name recorded: the test is expected to be a top-level function.
    top_level_match = next(
        (
            node
            for node in module.body
            if isinstance(node, cst.FunctionDef) and node.name.value == self.test_function_name
        ),
        None,
    )
    if top_level_match is None:
        return None
    return module.code_for_node(top_level_match).strip()
533+
508534
@staticmethod
509535
def from_str_id(string_id: str, iteration_id: str | None = None) -> InvocationId:
510536
components = string_id.split(":")
@@ -549,7 +575,10 @@ class TestResults(BaseModel): # noqa: PLW1641
549575
# also we don't support deletion of test results elements - caution is advised
550576
test_results: list[FunctionTestInvocation] = []
551577
test_result_idx: dict[str, int] = {}
578+
552579
perf_stdout: Optional[str] = None
580+
# mapping between test function name and stdout failure message
581+
test_failures: Optional[dict[str, str]] = None
553582

554583
def add(self, function_test_invocation: FunctionTestInvocation) -> None:
555584
unique_id = function_test_invocation.unique_invocation_loop_id

codeflash/optimization/function_optimizer.py

Lines changed: 21 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1752,6 +1752,11 @@ def establish_original_code_baseline(
17521752
)
17531753
)
17541754

1755+
def get_results_not_matched_error(self) -> Failure:
    """Report a behavioral mismatch and build the corresponding failure.

    Logs that the candidate's test results differ from those of the original
    code, draws a console rule, and returns the ``Failure`` describing it.

    Returns:
        A ``Failure`` carrying the mismatch message.
    """
    logger.info("h4|Test results did not match the test results of the original code ❌")
    console.rule()
    mismatch_failure = Failure("Test results did not match the test results of the original code.")
    return mismatch_failure
1759+
17551760
def run_optimized_candidate(
17561761
self,
17571762
*,
@@ -1808,13 +1813,25 @@ def run_optimized_candidate(
18081813
)
18091814
)
18101815
console.rule()
1811-
if compare_test_results(baseline_results.behavior_test_results, candidate_behavior_results):
1816+
match, diffs = compare_test_results(baseline_results.behavior_test_results, candidate_behavior_results)
1817+
if match:
18121818
logger.info("h3|Test results matched ✅")
18131819
console.rule()
18141820
else:
1815-
logger.info("h4|Test results did not match the test results of the original code ❌")
1816-
console.rule()
1817-
return Failure("Test results did not match the test results of the original code.")
1821+
result_unmatched_perc = len(diffs) / len(candidate_behavior_results)
1822+
if result_unmatched_perc > 0.5:
1823+
# if the test unmatched percentage is greater than 50%, we can't fix it
1824+
return self.get_results_not_matched_error()
1825+
1826+
# with the parsed test results diff ask the llm to fix the candidate to match the test results of the original code, and run again
1827+
# self.run_optimized_candidate(
1828+
# optimization_candidate_index=optimization_candidate_index,
1829+
# baseline_results=baseline_results,
1830+
# original_helper_code=original_helper_code,
1831+
# file_path_to_helper_classes=file_path_to_helper_classes,
1832+
# )
1833+
print(f"should try to fix it, diffs: {diffs}")
1834+
return self.get_results_not_matched_error()
18181835

18191836
logger.info(f"loading|Running performance tests for candidate {optimization_candidate_index}...")
18201837

Lines changed: 61 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
11
import sys
2+
from dataclasses import dataclass
3+
from enum import Enum
24

35
from codeflash.cli_cmds.console import logger
46
from codeflash.models.models import TestResults, TestType, VerificationType
@@ -7,21 +9,38 @@
79
INCREASED_RECURSION_LIMIT = 5000
810

911

10-
def compare_test_results(original_results: TestResults, candidate_results: TestResults) -> bool:
12+
class TestDiffScope(Enum):
    """Aspect of a test result in which two runs were found to differ."""

    # The values returned by the two runs differ.
    RETURN_VALUE = "return_value"
    # The captured standard output of the two runs differs.
    STDOUT = "stdout"
    # Timeout status (member declared for completeness of the scope set).
    TIMED_OUT = "timed_out"
    # Pass/fail status of the two runs differs.
    DID_PASS = "did_pass"  # noqa: S105
17+
18+
19+
@dataclass
20+
class TestDiff:
21+
scope: TestDiffScope
22+
test_src_code: str
23+
pytest_error: str
24+
original_value: any
25+
candidate_value: any
26+
27+
28+
def compare_test_results(original_results: TestResults, candidate_results: TestResults) -> tuple[bool, list[TestDiff]]:
1129
# This is meant to be only called with test results for the first loop index
1230
if len(original_results) == 0 or len(candidate_results) == 0:
13-
return False # empty test results are not equal
31+
return False, [] # empty test results are not equal
1432
original_recursion_limit = sys.getrecursionlimit()
1533
if original_recursion_limit < INCREASED_RECURSION_LIMIT:
1634
sys.setrecursionlimit(INCREASED_RECURSION_LIMIT) # Increase recursion limit to avoid RecursionError
1735
test_ids_superset = original_results.get_all_unique_invocation_loop_ids().union(
1836
set(candidate_results.get_all_unique_invocation_loop_ids())
1937
)
20-
are_equal: bool = True
38+
test_diffs: list[TestDiff] = []
2139
did_all_timeout: bool = True
2240
for test_id in test_ids_superset:
2341
original_test_result = original_results.get_by_unique_invocation_loop_id(test_id)
2442
cdd_test_result = candidate_results.get_by_unique_invocation_loop_id(test_id)
43+
candidate_pytest_error = candidate_results.test_failures.get(original_test_result.id.test_function_name)
2544
if cdd_test_result is not None and original_test_result is None:
2645
continue
2746
# If helper function instance_state verification is not present, that's ok. continue
@@ -32,8 +51,7 @@ def compare_test_results(original_results: TestResults, candidate_results: TestR
3251
):
3352
continue
3453
if original_test_result is None or cdd_test_result is None:
35-
are_equal = False
36-
break
54+
return False, []
3755
did_all_timeout = did_all_timeout and original_test_result.timed_out
3856
if original_test_result.timed_out:
3957
continue
@@ -43,31 +61,42 @@ def compare_test_results(original_results: TestResults, candidate_results: TestR
4361
in {VerificationType.INIT_STATE_HELPER, VerificationType.INIT_STATE_FTO}
4462
):
4563
superset_obj = True
64+
test_src_code = original_test_result.id.get_src_code(original_test_result.file_name)
4665
if not comparator(original_test_result.return_value, cdd_test_result.return_value, superset_obj=superset_obj):
47-
are_equal = False
66+
test_diffs.append(
67+
TestDiff(
68+
scope=TestDiffScope.RETURN_VALUE,
69+
test_src_code=test_src_code,
70+
original_value=original_test_result.return_value,
71+
candidate_value=cdd_test_result.return_value,
72+
pytest_error=candidate_pytest_error,
73+
)
74+
)
75+
4876
try:
49-
logger.debug(
50-
"File Name: %s\n"
51-
"Test Type: %s\n"
52-
"Verification Type: %s\n"
53-
"Invocation ID: %s\n"
54-
"Original return value: %s\n"
55-
"Candidate return value: %s\n"
56-
"-------------------",
57-
original_test_result.file_name,
58-
original_test_result.test_type,
59-
original_test_result.verification_type,
60-
original_test_result.id,
61-
original_test_result.return_value,
62-
cdd_test_result.return_value,
77+
print(
78+
f"File Name: {original_test_result.file_name}\n"
79+
f"Test Type: {original_test_result.test_type}\n"
80+
f"Verification Type: {original_test_result.verification_type}\n"
81+
f"Invocation ID: {original_test_result.id}\n"
82+
f"Original return value: {original_test_result.return_value}\n"
83+
f"Candidate return value: {cdd_test_result.return_value}\n"
6384
)
6485
except Exception as e:
6586
logger.error(e)
6687
break
6788
if (original_test_result.stdout and cdd_test_result.stdout) and not comparator(
6889
original_test_result.stdout, cdd_test_result.stdout
6990
):
70-
are_equal = False
91+
test_diffs.append(
92+
TestDiff(
93+
scope=TestDiffScope.STDOUT,
94+
test_src_code=test_src_code,
95+
original_value=original_test_result.stdout,
96+
candidate_value=cdd_test_result.stdout,
97+
pytest_error=candidate_pytest_error,
98+
)
99+
)
71100
break
72101

73102
if original_test_result.test_type in {
@@ -76,9 +105,17 @@ def compare_test_results(original_results: TestResults, candidate_results: TestR
76105
TestType.GENERATED_REGRESSION,
77106
TestType.REPLAY_TEST,
78107
} and (cdd_test_result.did_pass != original_test_result.did_pass):
79-
are_equal = False
108+
test_diffs.append(
109+
TestDiff(
110+
scope=TestDiffScope.DID_PASS,
111+
test_src_code=test_src_code,
112+
original_value=original_test_result.did_pass,
113+
candidate_value=cdd_test_result.did_pass,
114+
pytest_error=candidate_pytest_error,
115+
)
116+
)
80117
break
81118
sys.setrecursionlimit(original_recursion_limit)
82119
if did_all_timeout:
83-
return False
84-
return are_equal
120+
return False, test_diffs
121+
return len(test_diffs) == 0, test_diffs

codeflash/verification/parse_test_output.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -512,6 +512,43 @@ def merge_test_results(
512512
return merged_test_results
513513

514514

515+
def parse_test_failures_from_stdout(test_results: TestResults, stdout: str) -> TestResults:
516+
stdout_lines = stdout.splitlines()
517+
start_line = -1
518+
end_line = -1
519+
for i, line in enumerate(stdout_lines):
520+
if start_line != -1 and end_line != -1:
521+
break
522+
if "FAILURES" in line:
523+
start_line = i
524+
elif "short test summary info" in line:
525+
end_line = i
526+
if start_line == -1 or end_line == -1:
527+
return test_results
528+
529+
complete_failure_output_lines = stdout_lines[start_line:end_line] # exclude last summary line
530+
531+
test_case_to_failure: dict[str, str] = {}
532+
533+
current_test_case: str | None = None
534+
current_failure_lines: list[str] = []
535+
536+
for line in complete_failure_output_lines:
537+
if line.startswith("_______"):
538+
if current_test_case:
539+
test_case_to_failure[current_test_case] = "".join(current_failure_lines)
540+
current_test_case = line.strip("_ ").strip()
541+
current_failure_lines = []
542+
elif current_test_case:
543+
current_failure_lines.append(line + "\n")
544+
545+
if current_test_case:
546+
test_case_to_failure[current_test_case] = "".join(current_failure_lines)
547+
548+
test_results.test_failures = test_case_to_failure
549+
return test_results
550+
551+
515552
def parse_test_results(
516553
test_xml_path: Path,
517554
test_files: TestFiles,
@@ -572,4 +609,9 @@ def parse_test_results(
572609
function_name=function_name,
573610
)
574611
coverage.log_coverage()
612+
try:
613+
parse_test_failures_from_stdout(results, run_result.stdout)
614+
except Exception as e:
615+
logger.exception(e)
616+
575617
return results, coverage if all_args else None

0 commit comments

Comments
 (0)