From 5830a70c3c29d4e57eea55a4f8ae536c3c1b4f12 Mon Sep 17 00:00:00 2001 From: ali Date: Thu, 27 Nov 2025 16:18:55 +0200 Subject: [PATCH 01/22] quick and dirty --- codeflash/discovery/functions_to_optimize.py | 42 +++- codeflash/models/models.py | 29 +++ codeflash/optimization/function_optimizer.py | 25 +- codeflash/verification/equivalence.py | 85 +++++-- codeflash/verification/parse_test_output.py | 42 ++++ tests/test_codeflash_capture.py | 233 +++++++++++++++++++ 6 files changed, 416 insertions(+), 40 deletions(-) diff --git a/codeflash/discovery/functions_to_optimize.py b/codeflash/discovery/functions_to_optimize.py index b8cf895e1..3bee5fbf9 100644 --- a/codeflash/discovery/functions_to_optimize.py +++ b/codeflash/discovery/functions_to_optimize.py @@ -306,25 +306,43 @@ def levenshtein_distance(s1: str, s2: str) -> int: len1 = len(s1) len2 = len(s2) # Use a preallocated list instead of creating a new list every iteration + + # Early exit for empty string cases + if len1 == 0: + return len2 + if len2 == 0: + return len1 + + # Convert strings to lists for fast indexed access + s1_list = list(s1) + s2_list = list(s2) + + # Preallocate and reuse arrays; avoid creating new ones every iteration previous = list(range(len1 + 1)) current = [0] * (len1 + 1) for index2 in range(len2): - char2 = s2[index2] + char2 = s2_list[index2] current[0] = index2 + 1 + + # Remove redundant intermediate assignments for better cache locality + prev = previous + curr = current + s1_chars = s1_list + # Use local variables for frequently accessed values for index1 in range(len1): - char1 = s1[index1] - if char1 == char2: - current[index1 + 1] = previous[index1] + # Unrolling char1 assignment and equality check + if s1_chars[index1] == char2: + curr[index1 + 1] = prev[index1] else: - # Fast min calculation without tuple construct - a = previous[index1] - b = previous[index1 + 1] - c = current[index1] - min_val = min(b, a) - min_val = min(c, min_val) - current[index1 + 1] = 1 + min_val - # Swap references instead of copying + x = prev[index1] + y = prev[index1 + 1] + z = curr[index1] + min_xy = min(x, y) + min_xyz = min(z, min_xy) + curr[index1 + 1] = 1 + min_xyz + + # Swap references rather than copying data previous, current = current, previous return previous[len1] diff --git a/codeflash/models/models.py b/codeflash/models/models.py index 744f76087..647aa2a3d 100644 --- a/codeflash/models/models.py +++ b/codeflash/models/models.py @@ -3,6 +3,7 @@ from collections import Counter, defaultdict from typing import TYPE_CHECKING +import libcst as cst from rich.tree import Tree from codeflash.cli_cmds.console import DEBUG_MODE, lsp_log @@ -505,6 +506,31 @@ def id(self) -> str: f"{self.function_getting_tested}:{self.iteration_id}" ) + def find_func_in_class(self, class_node: cst.ClassDef, func_name: str) -> Optional[cst.FunctionDef]: + for stmt in class_node.body.body: + if isinstance(stmt, cst.FunctionDef) and stmt.name.value == func_name: + return stmt + return None + + def get_src_code(self, test_path: Path) -> Optional[str]: + test_src = test_path.read_text(encoding="utf-8") + module_node = cst.parse_module(test_src) + + if self.test_class_name: + for stmt in module_node.body: + if isinstance(stmt, cst.ClassDef) and stmt.name.value == self.test_class_name: + func_node = self.find_func_in_class(stmt, self.test_function_name) + if func_node: + return module_node.code_for_node(func_node).strip() + # class not found + return None + + # Otherwise, look for a top level function + for stmt in module_node.body: + if isinstance(stmt, cst.FunctionDef) and stmt.name.value == self.test_function_name: + return module_node.code_for_node(stmt).strip() + return None + @staticmethod def from_str_id(string_id: str, iteration_id: str | None = None) -> InvocationId: components = string_id.split(":") @@ -549,7 +575,10 @@ class TestResults(BaseModel): # noqa: PLW1641 # also we don't support deletion of test results elements - caution is advised test_results: list[FunctionTestInvocation] = [] test_result_idx: dict[str, int] = {} + perf_stdout: Optional[str] = None + # mapping between test function name and stdout failure message + test_failures: Optional[dict[str, str]] = None def add(self, function_test_invocation: FunctionTestInvocation) -> None: unique_id = function_test_invocation.unique_invocation_loop_id diff --git a/codeflash/optimization/function_optimizer.py b/codeflash/optimization/function_optimizer.py index 2eef51f0f..d948da052 100644 --- a/codeflash/optimization/function_optimizer.py +++ b/codeflash/optimization/function_optimizer.py @@ -1752,6 +1752,11 @@ def establish_original_code_baseline( ) ) + def get_results_not_matched_error(self) -> Failure: + logger.info("h4|Test results did not match the test results of the original code ❌") + console.rule() + return Failure("Test results did not match the test results of the original code.") + def run_optimized_candidate( self, *, @@ -1808,13 +1813,25 @@ def run_optimized_candidate( ) ) console.rule() - if compare_test_results(baseline_results.behavior_test_results, candidate_behavior_results): + match, diffs = compare_test_results(baseline_results.behavior_test_results, candidate_behavior_results) + if match: logger.info("h3|Test results matched ✅") console.rule() else: - logger.info("h4|Test results did not match the test results of the original code ❌") - console.rule() - return Failure("Test results did not match the test results of the original code.") + result_unmatched_perc = len(diffs) / len(candidate_behavior_results) + if result_unmatched_perc > 0.5: + # if the test unmatched percentage is greater than 50%, we can't fix it + return self.get_results_not_matched_error() + + # with the parsed test results diff ask the llm to fix the candidate to match the test results of the original code, and run again + # self.run_optimized_candidate( + # optimization_candidate_index=optimization_candidate_index, + # baseline_results=baseline_results, + # original_helper_code=original_helper_code, + # file_path_to_helper_classes=file_path_to_helper_classes, + # ) + print(f"should try to fix it, diffs: {diffs}") + return self.get_results_not_matched_error() logger.info(f"loading|Running performance tests for candidate {optimization_candidate_index}...") diff --git a/codeflash/verification/equivalence.py b/codeflash/verification/equivalence.py index 9d7f5ba2c..614fc66b4 100644 --- a/codeflash/verification/equivalence.py +++ b/codeflash/verification/equivalence.py @@ -1,4 +1,6 @@ import sys +from dataclasses import dataclass +from enum import Enum from codeflash.cli_cmds.console import logger from codeflash.models.models import TestResults, TestType, VerificationType @@ -7,21 +9,38 @@ INCREASED_RECURSION_LIMIT = 5000 -def compare_test_results(original_results: TestResults, candidate_results: TestResults) -> bool: +class TestDiffScope(Enum): + RETURN_VALUE = "return_value" + STDOUT = "stdout" + TIMED_OUT = "timed_out" + DID_PASS = "did_pass" # noqa: S105 + + +@dataclass +class TestDiff: + scope: TestDiffScope + test_src_code: str + pytest_error: str + original_value: any + candidate_value: any + + +def compare_test_results(original_results: TestResults, candidate_results: TestResults) -> tuple[bool, list[TestDiff]]: # This is meant to be only called with test results for the first loop index if len(original_results) == 0 or len(candidate_results) == 0: - return False # empty test results are not equal + return False, [] # empty test results are not equal original_recursion_limit = sys.getrecursionlimit() if original_recursion_limit < INCREASED_RECURSION_LIMIT: sys.setrecursionlimit(INCREASED_RECURSION_LIMIT) # Increase recursion limit to avoid RecursionError test_ids_superset = original_results.get_all_unique_invocation_loop_ids().union( set(candidate_results.get_all_unique_invocation_loop_ids()) ) - are_equal: bool = True + test_diffs: list[TestDiff] = [] did_all_timeout: bool = True for test_id in test_ids_superset: original_test_result = original_results.get_by_unique_invocation_loop_id(test_id) cdd_test_result = candidate_results.get_by_unique_invocation_loop_id(test_id) + candidate_pytest_error = candidate_results.test_failures.get(original_test_result.id.test_function_name) if cdd_test_result is not None and original_test_result is None: continue # If helper function instance_state verification is not present, that's ok. continue @@ -32,8 +51,7 @@ def compare_test_results(original_results: TestResults, candidate_results: TestR ): continue if original_test_result is None or cdd_test_result is None: - are_equal = False - break + return False, [] did_all_timeout = did_all_timeout and original_test_result.timed_out if original_test_result.timed_out: continue @@ -43,23 +61,26 @@ def compare_test_results(original_results: TestResults, candidate_results: TestR in {VerificationType.INIT_STATE_HELPER, VerificationType.INIT_STATE_FTO} ): superset_obj = True + test_src_code = original_test_result.id.get_src_code(original_test_result.file_name) if not comparator(original_test_result.return_value, cdd_test_result.return_value, superset_obj=superset_obj): - are_equal = False + test_diffs.append( + TestDiff( + scope=TestDiffScope.RETURN_VALUE, + test_src_code=test_src_code, + original_value=original_test_result.return_value, + candidate_value=cdd_test_result.return_value, + pytest_error=candidate_pytest_error, + ) + ) + try: - logger.debug( - "File Name: %s\n" - "Test Type: %s\n" - "Verification Type: %s\n" - "Invocation ID: %s\n" - "Original return value: %s\n" - "Candidate return value: %s\n" - "-------------------", - original_test_result.file_name, - original_test_result.test_type, - original_test_result.verification_type, - original_test_result.id, - original_test_result.return_value, - cdd_test_result.return_value, + print( + f"File Name: {original_test_result.file_name}\n" + f"Test Type: {original_test_result.test_type}\n" + f"Verification Type: {original_test_result.verification_type}\n" + f"Invocation ID: {original_test_result.id}\n" + f"Original return value: {original_test_result.return_value}\n" + f"Candidate return value: {cdd_test_result.return_value}\n" ) except Exception as e: logger.error(e) @@ -67,7 +88,15 @@ def compare_test_results(original_results: TestResults, candidate_results: TestR if (original_test_result.stdout and cdd_test_result.stdout) and not comparator( original_test_result.stdout, cdd_test_result.stdout ): - are_equal = False + test_diffs.append( + TestDiff( + scope=TestDiffScope.STDOUT, + test_src_code=test_src_code, + original_value=original_test_result.stdout, + candidate_value=cdd_test_result.stdout, + pytest_error=candidate_pytest_error, + ) + ) break if original_test_result.test_type in { @@ -76,9 +105,17 @@ def compare_test_results(original_results: TestResults, candidate_results: TestR TestType.GENERATED_REGRESSION, TestType.REPLAY_TEST, } and (cdd_test_result.did_pass != original_test_result.did_pass): - are_equal = False + test_diffs.append( + TestDiff( + scope=TestDiffScope.DID_PASS, + test_src_code=test_src_code, + original_value=original_test_result.did_pass, + candidate_value=cdd_test_result.did_pass, + pytest_error=candidate_pytest_error, + ) + ) break sys.setrecursionlimit(original_recursion_limit) if did_all_timeout: - return False - return are_equal + return False, test_diffs + return len(test_diffs) == 0, test_diffs diff --git a/codeflash/verification/parse_test_output.py b/codeflash/verification/parse_test_output.py index ef513a0a3..bbcf21adc 100644 --- a/codeflash/verification/parse_test_output.py +++ b/codeflash/verification/parse_test_output.py @@ -512,6 +512,43 @@ def merge_test_results( return merged_test_results +def parse_test_failures_from_stdout(test_results: TestResults, stdout: str) -> TestResults: + stdout_lines = stdout.splitlines() + start_line = -1 + end_line = -1 + for i, line in enumerate(stdout_lines): + if start_line != -1 and end_line != -1: + break + if "FAILURES" in line: + start_line = i + elif "short test summary info" in line: + end_line = i + if start_line == -1 or end_line == -1: + return test_results + + complete_failure_output_lines = stdout_lines[start_line:end_line] # exclude last summary line + + test_case_to_failure: dict[str, str] = {} + + current_test_case: str | None = None + current_failure_lines: list[str] = [] + + for line in complete_failure_output_lines: + if line.startswith("_______"): + if current_test_case: + test_case_to_failure[current_test_case] = "".join(current_failure_lines) + current_test_case = line.strip("_ ").strip() + current_failure_lines = [] + elif current_test_case: + current_failure_lines.append(line + "\n") + + if current_test_case: + test_case_to_failure[current_test_case] = "".join(current_failure_lines) + + test_results.test_failures = test_case_to_failure + return test_results + + def parse_test_results( test_xml_path: Path, test_files: TestFiles, @@ -572,4 +609,9 @@ def parse_test_results( function_name=function_name, ) coverage.log_coverage() + try: + parse_test_failures_from_stdout(results, run_result.stdout) + except Exception as e: + logger.exception(e) + return results, coverage if all_args else None diff --git a/tests/test_codeflash_capture.py b/tests/test_codeflash_capture.py index c326cecc4..5f1bcb858 100644 --- a/tests/test_codeflash_capture.py +++ b/tests/test_codeflash_capture.py @@ -1177,3 +1177,236 @@ def target_function(self): fto_file_path.unlink(missing_ok=True) helper_path_1.unlink(missing_ok=True) helper_path_2.unlink(missing_ok=True) + +def test_instrument_codeflash_capture_and_run_tests_2() -> None: + # End to end run that instruments code and runs tests. Made to be similar to code used in the optimizer.py + test_code = """import math +import pytest +from typing import List, Tuple, Optional +from code_to_optimize.tests.pytest.fto_file import calculate_portfolio_metrics + +def test_calculate_portfolio_metrics(): + # Test case 1: Basic portfolio + investments = [ + ('Stocks', 0.6, 0.12), + ('Bonds', 0.3, 0.04), + ('Cash', 0.1, 0.01) + ] + + result = calculate_portfolio_metrics(investments) + + # Check weighted return calculation + expected_return = 0.6*0.12 + 0.3*0.04 + 0.1*0.01 + assert abs(result['weighted_return'] - expected_return) < 1e-10 + + # Check volatility calculation + expected_vol = math.sqrt((0.6*0.12)**2 + (0.3*0.04)**2 + (0.1*0.01)**2) + assert abs(result['volatility'] - expected_vol) < 1e-10 + + # Check Sharpe ratio + expected_sharpe = (expected_return - 0.02) / expected_vol + assert abs(result['sharpe_ratio'] - expected_sharpe) < 1e-10 + + # Check best/worst performers + assert result['best_performing'][0] == 'Stocks' + assert result['worst_performing'][0] == 'Cash' + assert result['total_assets'] == 3 + +def test_empty_investments(): + with pytest.raises(ValueError, match="Investments list cannot be empty"): + calculate_portfolio_metrics([]) + +def test_weights_not_sum_to_one(): + investments = [('Stock', 0.5, 0.1), ('Bond', 0.4, 0.05)] + with pytest.raises(ValueError, match="Portfolio weights must sum to 1.0"): + calculate_portfolio_metrics(investments) + +def test_zero_volatility(): + investments = [('Cash', 1.0, 0.0)] + result = calculate_portfolio_metrics(investments, risk_free_rate=0.0) + assert result['sharpe_ratio'] == 0.0 + assert result['volatility'] == 0.0 +""" + + original_code = """import math +from typing import List, Tuple, Optional + +def calculate_portfolio_metrics( + investments: List[Tuple[str, float, float]], + risk_free_rate: float = 0.02 +) -> dict: + if not investments: + raise ValueError("Investments list cannot be empty") + + if abs(sum(weight for _, weight, _ in investments) - 1.0) > 1e-10: + raise ValueError("Portfolio weights must sum to 1.0") + + # Calculate weighted return + weighted_return = sum(weight * ret for _, weight, ret in investments) + + # Calculate portfolio volatility (simplified) + volatility = math.sqrt(sum((weight * ret) ** 2 for _, weight, ret in investments)) + + # Calculate Sharpe ratio + if volatility == 0: + sharpe_ratio = 0.0 + else: + sharpe_ratio = (weighted_return - risk_free_rate) / volatility + + # Find best and worst performing assets + best_asset = max(investments, key=lambda x: x[2]) + worst_asset = min(investments, key=lambda x: x[2]) + + return { + 'weighted_return': round(weighted_return, 6), + 'volatility': round(volatility, 6), + 'sharpe_ratio': round(sharpe_ratio, 6), + 'best_performing': (best_asset[0], round(best_asset[2], 6)), + 'worst_performing': (worst_asset[0], round(worst_asset[2], 6)), + 'total_assets': len(investments) + } +""" + test_dir = (Path(__file__).parent.parent / "code_to_optimize" / "tests" / "pytest").resolve() + test_file_name = "test_multiple_helpers.py" + + fto_file_name = "fto_file.py" + + test_path = test_dir / test_file_name + test_path_perf = test_dir / "test_multiple_helpers_perf.py" + fto_file_path = test_dir / fto_file_name + + tests_root = Path(__file__).parent.resolve() / "../code_to_optimize/tests/pytest/" + project_root_path = (Path(__file__).parent / "..").resolve() + + try: + with fto_file_path.open("w") as f: + f.write(original_code) + with test_path.open("w") as f: + f.write(test_code) + + fto = FunctionToOptimize("calculate_portfolio_metrics", fto_file_path, parents=[]) + file_path_to_helper_class = { + } + instrument_codeflash_capture(fto, file_path_to_helper_class, tests_root) + test_env = os.environ.copy() + test_env["CODEFLASH_TEST_ITERATION"] = "0" + test_env["CODEFLASH_LOOP_INDEX"] = "1" + + test_type = TestType.EXISTING_UNIT_TEST + test_config = TestConfig( + tests_root=tests_root, + tests_project_rootdir=project_root_path, + project_root_path=project_root_path, + test_framework="pytest", + pytest_cmd="pytest", + ) + func_optimizer = FunctionOptimizer(function_to_optimize=fto, test_cfg=test_config) + func_optimizer.test_files = TestFiles( + test_files=[ + TestFile( + instrumented_behavior_file_path=test_path, + test_type=test_type, + original_file_path=test_path, + benchmarking_file_path=test_path_perf, + ) + ] + ) + # Code in optimizer.py + # Instrument codeflash capture + candidate_fto_code = Path(fto.file_path).read_text("utf-8") + candidate_helper_code = {} + for file_path in file_path_to_helper_class: + candidate_helper_code[file_path] = Path(file_path).read_text("utf-8") + file_path_to_helper_classes = { + } + instrument_codeflash_capture(fto, file_path_to_helper_classes, tests_root) + + test_results, coverage_data = func_optimizer.run_and_parse_tests( + testing_type=TestingMode.BEHAVIOR, + test_env=test_env, + test_files=func_optimizer.test_files, + optimization_iteration=0, + pytest_min_loops=1, + pytest_max_loops=1, + testing_time=0.1, + ) + + # Remove instrumentation + FunctionOptimizer.write_code_and_helpers(candidate_fto_code, candidate_helper_code, fto.file_path) + + # Now, let's say we optimize the code and make changes. + new_fto_code = """import math +from typing import List, Tuple, Optional + +def calculate_portfolio_metrics( + investments: List[Tuple[str, float, float]], + risk_free_rate: float = 0.02 +) -> dict: + if not investments: + raise ValueError("Investments list cannot be empty") + + total_weight = sum(w for _, w, _ in investments) + if total_weight != 1.0: # Should use tolerance check + raise ValueError("Portfolio weights must sum to 1.0") + + weighted_return = 1.0 + for _, weight, ret in investments: + weighted_return *= (1 + ret) ** weight + weighted_return = weighted_return - 1.0 # Convert back from geometric + + returns = [r for _, _, r in investments] + mean_return = sum(returns) / len(returns) + volatility = math.sqrt(sum((r - mean_return) ** 2 for r in returns) / len(returns)) + + # BUG 4: Sharpe ratio calculation is correct but uses wrong inputs + if volatility == 0: + sharpe_ratio = 0.0 + else: + sharpe_ratio = (weighted_return - risk_free_rate) / volatility + + def risk_adjusted_return(return_val, weight): + return (return_val - risk_free_rate) / (weight * return_val) if weight * return_val != 0 else return_val + + best_asset = max(investments, key=lambda x: risk_adjusted_return(x[2], x[1])) + worst_asset = min(investments, key=lambda x: risk_adjusted_return(x[2], x[1])) + + return { + "weighted_return": round(weighted_return, 6), + "volatility": 2, + "sharpe_ratio": round(sharpe_ratio, 6), + "best_performing": (best_asset[0], round(best_asset[2], 6)), + "worst_performing": (worst_asset[0], round(worst_asset[2], 6)), + "total_assets": len(investments), + } +""" + with fto_file_path.open("w") as f: + f.write(new_fto_code) + # Instrument codeflash capture + candidate_fto_code = Path(fto.file_path).read_text("utf-8") + candidate_helper_code = {} + for file_path in file_path_to_helper_class: + candidate_helper_code[file_path] = Path(file_path).read_text("utf-8") + file_path_to_helper_classes = { + # Path(helper_path_1): {"HelperClass1"}, + # Path(helper_path_2): {"HelperClass2", "AnotherHelperClass"}, + } + instrument_codeflash_capture(fto, file_path_to_helper_classes, tests_root) + modified_test_results, coverage_data = func_optimizer.run_and_parse_tests( + testing_type=TestingMode.BEHAVIOR, + test_env=test_env, + test_files=func_optimizer.test_files, + optimization_iteration=0, + pytest_min_loops=1, + pytest_max_loops=1, + testing_time=0.1, + ) + # Remove instrumentation + FunctionOptimizer.write_code_and_helpers(candidate_fto_code, candidate_helper_code, fto.file_path) + matched, diffs = compare_test_results(test_results, modified_test_results) + print(diffs) + + assert not matched + + finally: + test_path.unlink(missing_ok=True) + fto_file_path.unlink(missing_ok=True) \ No newline at end of file From 3e0440bf471a20b60b1c48d9862fb78f9902d113 Mon Sep 17 00:00:00 2001 From: ali Date: Thu, 27 Nov 2025 16:25:33 +0200 Subject: [PATCH 02/22] safter --- codeflash/verification/equivalence.py | 18 ++++++++++++++---- 1 file changed, 14 insertions(+), 4 deletions(-) diff --git a/codeflash/verification/equivalence.py b/codeflash/verification/equivalence.py index 614fc66b4..1e46f2ebc 100644 --- a/codeflash/verification/equivalence.py +++ b/codeflash/verification/equivalence.py @@ -40,7 +40,17 @@ def compare_test_results(original_results: TestResults, candidate_results: TestR for test_id in test_ids_superset: original_test_result = original_results.get_by_unique_invocation_loop_id(test_id) cdd_test_result = candidate_results.get_by_unique_invocation_loop_id(test_id) - candidate_pytest_error = candidate_results.test_failures.get(original_test_result.id.test_function_name) + candidate_test_failures = candidate_results.test_failures + # original_test_failures = original_results.test_failures + cdd_pytest_error = ( + candidate_test_failures.get(original_test_result.id.test_function_name, "") + if candidate_test_failures + else "" + ) + # original_pytest_error = ( + # original_test_failures.get(original_test_result.id.test_function_name, "") if original_test_failures else "" + # ) + if cdd_test_result is not None and original_test_result is None: continue # If helper function instance_state verification is not present, that's ok. continue @@ -69,7 +79,7 @@ def compare_test_results(original_results: TestResults, candidate_results: TestR test_src_code=test_src_code, original_value=original_test_result.return_value, candidate_value=cdd_test_result.return_value, - pytest_error=candidate_pytest_error, + pytest_error=cdd_pytest_error, ) ) @@ -94,7 +104,7 @@ def compare_test_results(original_results: TestResults, candidate_results: TestR test_src_code=test_src_code, original_value=original_test_result.stdout, candidate_value=cdd_test_result.stdout, - pytest_error=candidate_pytest_error, + pytest_error=cdd_pytest_error, ) ) break @@ -111,7 +121,7 @@ def compare_test_results(original_results: TestResults, candidate_results: TestR test_src_code=test_src_code, original_value=original_test_result.did_pass, candidate_value=cdd_test_result.did_pass, - pytest_error=candidate_pytest_error, + pytest_error=cdd_pytest_error, ) ) break From eb16cb2a0b1aa3be640b6ce42cb28c5930c3fbeb Mon Sep 17 00:00:00 2001 From: "codeflash-ai[bot]" <148906541+codeflash-ai[bot]@users.noreply.github.com> Date: Thu, 27 Nov 2025 14:49:05 +0000 Subject: [PATCH 03/22] Optimize parse_test_failures_from_stdout The optimized code achieves a **15% speedup** through several targeted micro-optimizations that reduce computational overhead in the parsing loop: **Key Optimizations:** 1. **Single-pass boundary search**: Instead of checking both conditions (`start_line != -1 and end_line != -1`) on every iteration, the optimized version uses `None` values and breaks immediately when both markers are found, eliminating redundant condition checks. 2. **Fast-path string matching**: Before calling the expensive `.startswith("_______")` method, it first checks if `line[0] == "_"`, avoiding the method call for most lines that don't start with underscores. 3. **Method lookup optimization**: Pulls `current_failure_lines.append` into a local variable to avoid repeated attribute lookups in the hot loop where failure lines are processed. 4. **Memory-efficient list management**: Uses `current_failure_lines.clear()` instead of creating new list objects (`current_failure_lines = []`), reducing object allocation pressure. **Performance Impact:** The optimizations show the most significant gains in large-scale scenarios: - **Large failure sets**: 14.2% faster with 500 failures, 14.0% faster with 999 failures - **Large output**: 29.2% faster for single failures with 1000 lines of output - **Complex scenarios**: 22.3% faster with 50 cases having 10 lines each **Hot Path Context:** Based on the function reference, `parse_test_failures_from_stdout` is called from `parse_test_results`, which appears to be part of a test optimization pipeline. The function processes pytest stdout to extract failure information, making it performance-critical when dealing with large test suites or verbose test outputs. The 15% improvement becomes meaningful when processing hundreds of test failures in CI/CD environments or during iterative code optimization workflows. --- codeflash/verification/parse_test_output.py | 31 ++++++++++++++------- 1 file changed, 21 insertions(+), 10 deletions(-) diff --git a/codeflash/verification/parse_test_output.py b/codeflash/verification/parse_test_output.py index bbcf21adc..47ad5738a 100644 --- a/codeflash/verification/parse_test_output.py +++ b/codeflash/verification/parse_test_output.py @@ -514,16 +514,17 @@ def merge_test_results( def parse_test_failures_from_stdout(test_results: TestResults, stdout: str) -> TestResults: stdout_lines = stdout.splitlines() - start_line = -1 - end_line = -1 + start_line = end_line = None + + # optimize search for start/end by scanning once for i, line in enumerate(stdout_lines): - if start_line != -1 and end_line != -1: - break - if "FAILURES" in line: + if start_line is None and "FAILURES" in line: start_line = i - elif "short test summary info" in line: + elif start_line is not None and end_line is None and "short test summary info" in line: end_line = i - if start_line == -1 or end_line == -1: + break + + if start_line is None or end_line is None: return test_results complete_failure_output_lines = stdout_lines[start_line:end_line] # exclude last summary line @@ -533,14 +534,24 @@ def parse_test_failures_from_stdout(test_results: TestResults, stdout: str) -> T current_test_case: str | None = None current_failure_lines: list[str] = [] + # Avoid per-line string concatenation by tracking indices and performing join once per section + # Precompute the boundary check value + underline_prefix = "_______" + + # Minor: Pull into local variable to avoid attribute lookup inside loop + join_nl = "\n".join + append = current_failure_lines.append + for line in complete_failure_output_lines: - if line.startswith("_______"): + # Fast-path: avoid .startswith() unless it can possibly match + if line and line[0] == "_" and line.startswith(underline_prefix): if current_test_case: test_case_to_failure[current_test_case] = "".join(current_failure_lines) current_test_case = line.strip("_ ").strip() - current_failure_lines = [] + # Start new collection + current_failure_lines.clear() elif current_test_case: - current_failure_lines.append(line + "\n") + append(line + "\n") if current_test_case: test_case_to_failure[current_test_case] = "".join(current_failure_lines) From a7f8816f5ed8f4035b8c427b99737c46c06fa8f4 Mon Sep 17 00:00:00 2001 From: ali Date: Thu, 27 Nov 2025 19:51:55 +0200 Subject: [PATCH 04/22] fix tests --- codeflash/models/models.py | 2 ++ codeflash/verification/equivalence.py | 8 ++++++- tests/test_codeflash_capture.py | 18 +++++++++----- tests/test_comparator.py | 21 ++++++++++------ tests/test_instrument_all_and_run.py | 24 ++++++++++++------- ...t_instrumentation_run_results_aiservice.py | 10 ++++---- tests/test_pickle_patcher.py | 8 +++---- 7 files changed, 61 insertions(+), 30 deletions(-) diff --git a/codeflash/models/models.py b/codeflash/models/models.py index 647aa2a3d..48ecf396a 100644 --- a/codeflash/models/models.py +++ b/codeflash/models/models.py @@ -513,6 +513,8 @@ def find_func_in_class(self, class_node: cst.ClassDef, func_name: str) -> Option return None def get_src_code(self, test_path: Path) -> Optional[str]: + if not test_path.exists(): + return None test_src = test_path.read_text(encoding="utf-8") module_node = cst.parse_module(test_src) diff --git a/codeflash/verification/equivalence.py b/codeflash/verification/equivalence.py index 1e46f2ebc..77798d88f 100644 --- a/codeflash/verification/equivalence.py +++ b/codeflash/verification/equivalence.py @@ -1,11 +1,17 @@ +from __future__ import annotations + import sys from dataclasses import dataclass from enum import Enum +from typing import TYPE_CHECKING, Optional from codeflash.cli_cmds.console import logger from codeflash.models.models import TestResults, TestType, VerificationType from codeflash.verification.comparator import comparator +if TYPE_CHECKING: + from codeflash.models.models import TestResults + INCREASED_RECURSION_LIMIT = 5000 @@ -19,10 +25,10 @@ class TestDiffScope(Enum): @dataclass class TestDiff: scope: TestDiffScope - test_src_code: str pytest_error: str original_value: any candidate_value: any + test_src_code: Optional[str] = None def compare_test_results(original_results: TestResults, candidate_results: TestResults) -> tuple[bool, list[TestDiff]]: diff --git a/tests/test_codeflash_capture.py b/tests/test_codeflash_capture.py index 5f1bcb858..4ab52dd80 100644 --- a/tests/test_codeflash_capture.py +++ b/tests/test_codeflash_capture.py @@ -502,7 +502,8 @@ def __init__(self, x=2): pytest_max_loops=1, testing_time=0.1, ) - assert compare_test_results(test_results, test_results2) + match, _ = compare_test_results(test_results, test_results2) + assert match finally: test_path.unlink(missing_ok=True) @@ -626,7 +627,8 @@ def __init__(self, *args, **kwargs): testing_time=0.1, ) - assert compare_test_results(test_results, results2) + match, _ = compare_test_results(test_results, results2) + assert match finally: test_path.unlink(missing_ok=True) @@ -754,7 +756,8 @@ def __init__(self, x=2): testing_time=0.1, ) - assert compare_test_results(test_results, test_results2) + match, _ = compare_test_results(test_results, test_results2) + assert match finally: test_path.unlink(missing_ok=True) sample_code_path.unlink(missing_ok=True) @@ -902,7 +905,8 @@ def another_helper(self): testing_time=0.1, ) - assert compare_test_results(test_results, results2) + match, _ = compare_test_results(test_results, results2) + assert match finally: test_path.unlink(missing_ok=True) @@ -1132,7 +1136,8 @@ def target_function(self): ) # Remove instrumentation FunctionOptimizer.write_code_and_helpers(candidate_fto_code, candidate_helper_code, fto.file_path) - assert not compare_test_results(test_results, mutated_test_results) + match, _ = compare_test_results(test_results, mutated_test_results) + assert not match # This fto code stopped using a helper class. it should still pass no_helper1_fto_code = """ @@ -1170,7 +1175,8 @@ def target_function(self): ) # Remove instrumentation FunctionOptimizer.write_code_and_helpers(candidate_fto_code, candidate_helper_code, fto.file_path) - assert compare_test_results(test_results, no_helper1_test_results) + match, _ = compare_test_results(test_results, no_helper1_test_results) + assert match finally: test_path.unlink(missing_ok=True) diff --git a/tests/test_comparator.py b/tests/test_comparator.py index 06d178f95..6c2781229 100644 --- a/tests/test_comparator.py +++ b/tests/test_comparator.py @@ -1176,7 +1176,8 @@ def test_compare_results_fn(): ) ) - assert compare_test_results(original_results, new_results_1) + match, _ = compare_test_results(original_results, new_results_1) + assert match new_results_2 = TestResults() new_results_2.add( @@ -1199,7 +1200,8 @@ def test_compare_results_fn(): ) ) - assert not compare_test_results(original_results, new_results_2) + match, _ = compare_test_results(original_results, new_results_2) + assert not match new_results_3 = TestResults() new_results_3.add( @@ -1241,7 +1243,8 @@ def test_compare_results_fn(): ) ) - assert compare_test_results(original_results, new_results_3) + match, _ = compare_test_results(original_results, new_results_3) + assert match new_results_4 = TestResults() new_results_4.add( @@ -1264,7 +1267,8 @@ def test_compare_results_fn(): ) ) - assert not compare_test_results(original_results, new_results_4) + match, _ = compare_test_results(original_results, new_results_4) + assert not match new_results_5_baseline = TestResults() new_results_5_baseline.add( @@ -1308,7 +1312,8 @@ def test_compare_results_fn(): ) ) - assert not compare_test_results(new_results_5_baseline, new_results_5_opt) + match, _ = compare_test_results(new_results_5_baseline, new_results_5_opt) + assert not match new_results_6_baseline = TestResults() new_results_6_baseline.add( @@ -1352,9 +1357,11 @@ def test_compare_results_fn(): ) ) - assert not compare_test_results(new_results_6_baseline, new_results_6_opt) + match, _ = compare_test_results(new_results_6_baseline, new_results_6_opt) + assert not match - assert not compare_test_results(TestResults(), TestResults()) + match, _ = compare_test_results(TestResults(), TestResults()) + assert not match def test_exceptions(): diff --git a/tests/test_instrument_all_and_run.py b/tests/test_instrument_all_and_run.py index ece7d38b0..7bdfa364b 100644 --- a/tests/test_instrument_all_and_run.py +++ b/tests/test_instrument_all_and_run.py @@ -223,7 +223,8 @@ def test_sort(): result: [0, 1, 2, 3, 4, 5] """ assert out_str == results2[0].stdout - assert compare_test_results(test_results, results2) + match, _ = compare_test_results(test_results, results2) + assert match finally: fto_path.write_text(original_code, "utf-8") test_path.unlink(missing_ok=True) @@ -368,7 +369,8 @@ def test_sort(): assert test_results[1].return_value == ([0, 1, 2, 3, 4, 5],) out_str = """codeflash stdout : BubbleSorter.sorter() called\n""" assert test_results[1].stdout == out_str - assert compare_test_results(test_results, test_results) + match, _ = compare_test_results(test_results, test_results) + assert match assert test_results[2].id.function_getting_tested == "BubbleSorter.__init__" assert test_results[2].id.test_function_name == "test_sort" assert test_results[2].did_pass @@ -396,7 +398,8 @@ def test_sort(): testing_time=0.1, ) - assert compare_test_results(test_results, results2) + match, _ = compare_test_results(test_results, results2) + assert match # Replace with optimized code that mutated instance attribute optimized_code = """ @@ -491,7 +494,8 @@ def sorter(self, arr): ) assert new_test_results[3].runtime > 0 assert new_test_results[3].did_pass - assert not compare_test_results(test_results, new_test_results) + match, _ = compare_test_results(test_results, new_test_results) + assert not match finally: fto_path.write_text(original_code, "utf-8") @@ -630,7 +634,8 @@ def test_sort(): out_str = """codeflash stdout : BubbleSorter.sorter_classmethod() called """ assert test_results[0].stdout == out_str - assert compare_test_results(test_results, test_results) + match, _ = compare_test_results(test_results, test_results) + assert match assert test_results[1].id.function_getting_tested == "BubbleSorter.sorter_classmethod" assert test_results[1].id.iteration_id == "4_0" @@ -655,7 +660,8 @@ def test_sort(): testing_time=0.1, ) - assert compare_test_results(test_results, results2) + match, _ = compare_test_results(test_results, results2) + assert match finally: fto_path.write_text(original_code, "utf-8") @@ -794,7 +800,8 @@ def test_sort(): out_str = """codeflash stdout : BubbleSorter.sorter_staticmethod() called """ assert test_results[0].stdout == out_str - assert compare_test_results(test_results, test_results) + match, _ = compare_test_results(test_results, test_results) + assert match assert test_results[1].id.function_getting_tested == "BubbleSorter.sorter_staticmethod" assert test_results[1].id.iteration_id == "4_0" @@ -819,7 +826,8 @@ def test_sort(): testing_time=0.1, ) - assert compare_test_results(test_results, results2) + match, _ = compare_test_results(test_results, results2) + assert match finally: fto_path.write_text(original_code, "utf-8") diff --git a/tests/test_instrumentation_run_results_aiservice.py b/tests/test_instrumentation_run_results_aiservice.py index cae2c76f1..03556718d 100644 --- a/tests/test_instrumentation_run_results_aiservice.py +++ b/tests/test_instrumentation_run_results_aiservice.py @@ -221,10 +221,10 @@ def sorter(self, arr): testing_time=0.1, ) # assert test_results_mutated_attr[0].return_value[1]["self"].x == 1 TODO: add self as input to function - assert compare_test_results( + match, _ = compare_test_results( test_results, test_results_mutated_attr ) # Without codeflash capture, the init state was not verified, and the results are verified as correct even with the attribute mutated - + assert match assert test_results_mutated_attr[0].stdout == "codeflash stdout : BubbleSorter.sorter() called\n" finally: fto_path.write_text(original_code, "utf-8") @@ -403,9 +403,10 @@ def sorter(self, arr): assert test_results_mutated_attr[0].return_value[0] == {"x": 1} assert test_results_mutated_attr[0].verification_type == VerificationType.INIT_STATE_FTO assert test_results_mutated_attr[0].stdout == "" - assert not compare_test_results( + match,_ = compare_test_results( test_results, test_results_mutated_attr ) # The test should fail because the instance attribute was mutated + assert not match # Replace with optimized code that did not mutate existing instance attribute, but added a new one optimized_code_new_attr = """ import sys @@ -457,9 +458,10 @@ def sorter(self, arr): assert test_results_new_attr[0].stdout == "" # assert test_results_new_attr[1].return_value[1]["self"].x == 0 TODO: add self as input # assert test_results_new_attr[1].return_value[1]["self"].y == 2 TODO: add self as input - assert compare_test_results( + match,_ = compare_test_results( test_results, test_results_new_attr ) # The test should pass because the instance attribute was not mutated, only a new one was added + assert match finally: fto_path.write_text(original_code, "utf-8") test_path.unlink(missing_ok=True) diff --git a/tests/test_pickle_patcher.py b/tests/test_pickle_patcher.py index c67883c12..c05384d03 100644 --- a/tests/test_pickle_patcher.py +++ b/tests/test_pickle_patcher.py @@ -427,8 +427,8 @@ def bubble_sort_with_unused_socket(data_container): testing_time=1.0, ) assert len(optimized_test_results_unused_socket) == 1 - verification_result = compare_test_results(test_results_unused_socket, optimized_test_results_unused_socket) - assert verification_result is True + match, _ = compare_test_results(test_results_unused_socket, optimized_test_results_unused_socket) + assert match # Remove the previous instrumentation replay_test_path.write_text(original_replay_test_code) @@ -517,8 +517,8 @@ def bubble_sort_with_used_socket(data_container): assert test_results_used_socket.test_results[0].did_pass is False # Even though tests threw the same error, we reject this as the behavior of the unpickleable object could not be determined. - assert compare_test_results(test_results_used_socket, optimized_test_results_used_socket) is False - + match, _ = compare_test_results(test_results_used_socket, optimized_test_results_used_socket) + assert not match finally: # cleanup output_file.unlink(missing_ok=True) From 4e9f8941d4b9b91a4575271c17ad8070e8f36549 Mon Sep 17 00:00:00 2001 From: ali Date: Thu, 27 Nov 2025 19:56:02 +0200 Subject: [PATCH 05/22] linting --- codeflash/verification/parse_test_output.py | 11 +---------- 1 file changed, 1 insertion(+), 10 deletions(-) diff --git a/codeflash/verification/parse_test_output.py b/codeflash/verification/parse_test_output.py index 47ad5738a..873f9f341 100644 --- a/codeflash/verification/parse_test_output.py +++ b/codeflash/verification/parse_test_output.py @@ -516,7 +516,6 @@ def parse_test_failures_from_stdout(test_results: TestResults, stdout: str) -> T stdout_lines = stdout.splitlines() start_line = end_line = None - # optimize search for start/end by scanning once for i, line in enumerate(stdout_lines): if start_line is None and "FAILURES" in line: start_line = i @@ -534,24 +533,16 @@ def parse_test_failures_from_stdout(test_results: TestResults, stdout: str) -> T current_test_case: str | None = None current_failure_lines: list[str] = [] - # Avoid per-line string concatenation by tracking indices and performing join once per section - # Precompute the boundary check value underline_prefix = "_______" - # Minor: Pull into local variable to avoid attribute lookup inside loop - join_nl = "\n".join - append = current_failure_lines.append - for line in complete_failure_output_lines: - # Fast-path: avoid .startswith() unless it can possibly match if line and line[0] == "_" and line.startswith(underline_prefix): if current_test_case: test_case_to_failure[current_test_case] = "".join(current_failure_lines) current_test_case = line.strip("_ ").strip() - # Start new collection current_failure_lines.clear() elif current_test_case: - append(line + "\n") + current_failure_lines.append(line + "\n") if current_test_case: test_case_to_failure[current_test_case] = "".join(current_failure_lines) From 1c9abaf0ff298f0a772c93c3d9facdd299cdfadb Mon Sep 17 00:00:00 2001 From: ali Date: Fri, 28 Nov 2025 10:49:07 +0200 Subject: [PATCH 06/22] did it pass ? --- codeflash/optimization/function_optimizer.py | 2 +- codeflash/verification/equivalence.py | 10 +++++++++- 2 files changed, 10 insertions(+), 2 deletions(-) diff --git a/codeflash/optimization/function_optimizer.py b/codeflash/optimization/function_optimizer.py index d948da052..9f3d1cfd0 100644 --- a/codeflash/optimization/function_optimizer.py +++ b/codeflash/optimization/function_optimizer.py @@ -1823,6 +1823,7 @@ def run_optimized_candidate( # if the test unmatched percentage is greater than 50%, we can't fix it return self.get_results_not_matched_error() + print(f"should try to fix it, diffs: {diffs}") # with the parsed test results diff ask the llm to fix the candidate to match the test results of the original code, and run again # self.run_optimized_candidate( # optimization_candidate_index=optimization_candidate_index, @@ -1830,7 +1831,6 @@ def run_optimized_candidate( # original_helper_code=original_helper_code, # file_path_to_helper_classes=file_path_to_helper_classes, # ) - print(f"should try to fix it, diffs: {diffs}") return self.get_results_not_matched_error() logger.info(f"loading|Running performance tests for candidate {optimization_candidate_index}...") diff --git a/codeflash/verification/equivalence.py b/codeflash/verification/equivalence.py index 77798d88f..d39433416 100644 --- a/codeflash/verification/equivalence.py +++ b/codeflash/verification/equivalence.py @@ -18,8 +18,8 @@ class TestDiffScope(Enum): RETURN_VALUE = "return_value" STDOUT = "stdout" - TIMED_OUT = "timed_out" DID_PASS = "did_pass" # noqa: S105 + TIMED_OUT = "timed_out" @dataclass @@ -28,6 +28,8 @@ class TestDiff: pytest_error: str original_value: any candidate_value: any + original_pass: bool + candidate_pass: bool test_src_code: Optional[str] = None @@ -86,6 +88,8 @@ def compare_test_results(original_results: TestResults, candidate_results: TestR original_value=original_test_result.return_value, candidate_value=cdd_test_result.return_value, pytest_error=cdd_pytest_error, + original_pass=original_test_result.did_pass, + candidate_pass=cdd_test_result.did_pass, ) ) @@ -111,6 +115,8 @@ def compare_test_results(original_results: TestResults, candidate_results: TestR original_value=original_test_result.stdout, candidate_value=cdd_test_result.stdout, pytest_error=cdd_pytest_error, + original_pass=original_test_result.did_pass, + candidate_pass=cdd_test_result.did_pass, ) ) break @@ -128,6 +134,8 @@ def compare_test_results(original_results: TestResults, candidate_results: TestR original_value=original_test_result.did_pass, candidate_value=cdd_test_result.did_pass, pytest_error=cdd_pytest_error, + original_pass=original_test_result.did_pass, + candidate_pass=cdd_test_result.did_pass, ) ) break From 0b2d894ce2ed7242ee7fc76ff676f6eef88ecbee Mon Sep 17 00:00:00 2001 From: ali Date: Fri, 28 Nov 2025 11:03:59 +0200 Subject: [PATCH 07/22] revert test optimization --- codeflash/discovery/functions_to_optimize.py | 42 ++++++-------------- 1 file changed, 12 insertions(+), 30 deletions(-) diff --git a/codeflash/discovery/functions_to_optimize.py b/codeflash/discovery/functions_to_optimize.py index 3bee5fbf9..b8cf895e1 100644 --- a/codeflash/discovery/functions_to_optimize.py +++ b/codeflash/discovery/functions_to_optimize.py @@ -306,43 +306,25 @@ def levenshtein_distance(s1: str, s2: str) -> int: len1 = len(s1) len2 = len(s2) # Use a preallocated list instead of creating a new list every iteration - - # Early exit for empty string cases - if len1 == 0: - return len2 - if len2 == 0: - return len1 - - # Convert strings to lists for fast indexed access - s1_list = list(s1) - s2_list = list(s2) - - # Preallocate and reuse arrays; avoid creating new ones every iteration previous = list(range(len1 + 1)) current = [0] * (len1 + 1) for index2 in range(len2): - char2 = s2_list[index2] + char2 = s2[index2] current[0] = index2 + 1 - - # Remove redundant intermediate assignments for better cache locality - prev = previous - curr = current - s1_chars = s1_list - # Use local variables for frequently accessed values for index1 in range(len1): - # Unrolling char1 assignment and equality check - if s1_chars[index1] == char2: - curr[index1 + 1] = prev[index1] + char1 = s1[index1] + if char1 == char2: + current[index1 + 1] = previous[index1] else: - x = prev[index1] - y = prev[index1 + 1] - z = curr[index1] - min_xy = min(x, y) - min_xyz = min(z, min_xy) - curr[index1 + 1] = 1 + min_xyz - - # Swap references rather than copying data + # Fast min calculation without tuple construct + a = previous[index1] + b = previous[index1 + 1] + c = current[index1] + min_val = min(b, a) + min_val = min(c, min_val) + current[index1 + 1] = 1 + min_val + # Swap references instead of copying previous, current = current, previous return previous[len1] From ecfa89fb344b3654df7b75b3d83bcf7929f7726c Mon Sep 17 00:00:00 2001 From: ali Date: Fri, 28 Nov 2025 11:22:16 +0200 Subject: [PATCH 08/22] cleaner --- codeflash/verification/equivalence.py | 70 ++++++++++++--------------- 1 file changed, 31 insertions(+), 39 deletions(-) diff --git a/codeflash/verification/equivalence.py b/codeflash/verification/equivalence.py index d39433416..27f020a9f 100644 --- a/codeflash/verification/equivalence.py +++ b/codeflash/verification/equivalence.py @@ -25,12 +25,14 @@ class TestDiffScope(Enum): @dataclass class TestDiff: scope: TestDiffScope - pytest_error: str original_value: any candidate_value: any original_pass: bool candidate_pass: bool + test_src_code: Optional[str] = None + candidate_pytest_error: Optional[str] = None + original_pytest_error: Optional[str] = None def compare_test_results(original_results: TestResults, candidate_results: TestResults) -> tuple[bool, list[TestDiff]]: @@ -49,15 +51,15 @@ def compare_test_results(original_results: TestResults, candidate_results: TestR original_test_result = original_results.get_by_unique_invocation_loop_id(test_id) cdd_test_result = candidate_results.get_by_unique_invocation_loop_id(test_id) candidate_test_failures = candidate_results.test_failures - # original_test_failures = original_results.test_failures + original_test_failures = original_results.test_failures cdd_pytest_error = ( candidate_test_failures.get(original_test_result.id.test_function_name, "") if candidate_test_failures else "" ) - # original_pytest_error = ( - # original_test_failures.get(original_test_result.id.test_function_name, "") if original_test_failures else "" - # ) + original_pytest_error = ( + original_test_failures.get(original_test_result.id.test_function_name, "") if original_test_failures else "" + ) if cdd_test_result is not None and original_test_result is None: continue @@ -79,22 +81,26 @@ def compare_test_results(original_results: TestResults, candidate_results: TestR in {VerificationType.INIT_STATE_HELPER, VerificationType.INIT_STATE_FTO} ): superset_obj = True + test_src_code = original_test_result.id.get_src_code(original_test_result.file_name) + test_diff = TestDiff( + scope=TestDiffScope.RETURN_VALUE, + original_value=original_test_result.return_value, + candidate_value=cdd_test_result.return_value, + test_src_code=test_src_code, + candidate_pytest_error=cdd_pytest_error, + original_pass=original_test_result.did_pass, + candidate_pass=cdd_test_result.did_pass, + original_pytest_error=original_pytest_error, + ) if not comparator(original_test_result.return_value, cdd_test_result.return_value, superset_obj=superset_obj): - test_diffs.append( - TestDiff( - scope=TestDiffScope.RETURN_VALUE, - test_src_code=test_src_code, - original_value=original_test_result.return_value, - candidate_value=cdd_test_result.return_value, - pytest_error=cdd_pytest_error, - original_pass=original_test_result.did_pass, - candidate_pass=cdd_test_result.did_pass, - ) - ) + test_diff.scope = TestDiffScope.RETURN_VALUE + test_diff.original_value = original_test_result.return_value + test_diff.candidate_value = cdd_test_result.return_value + test_diffs.append(test_diff) try: - print( + logger.debug( f"File Name: {original_test_result.file_name}\n" f"Test Type: {original_test_result.test_type}\n" f"Verification Type: {original_test_result.verification_type}\n" @@ -108,17 +114,10 @@ def compare_test_results(original_results: TestResults, candidate_results: TestR if (original_test_result.stdout and cdd_test_result.stdout) and not comparator( original_test_result.stdout, cdd_test_result.stdout ): - test_diffs.append( - TestDiff( - scope=TestDiffScope.STDOUT, - test_src_code=test_src_code, - original_value=original_test_result.stdout, - candidate_value=cdd_test_result.stdout, - pytest_error=cdd_pytest_error, - original_pass=original_test_result.did_pass, - candidate_pass=cdd_test_result.did_pass, - ) - ) + test_diff.scope = TestDiffScope.STDOUT + test_diff.original_value = original_test_result.stdout + test_diff.candidate_value = cdd_test_result.stdout + test_diffs.append(test_diff) break if original_test_result.test_type in { @@ -127,17 +126,10 @@ def compare_test_results(original_results: TestResults, candidate_results: TestR TestType.GENERATED_REGRESSION, TestType.REPLAY_TEST, } and (cdd_test_result.did_pass != original_test_result.did_pass): - test_diffs.append( - TestDiff( - scope=TestDiffScope.DID_PASS, - test_src_code=test_src_code, - original_value=original_test_result.did_pass, - candidate_value=cdd_test_result.did_pass, - pytest_error=cdd_pytest_error, - original_pass=original_test_result.did_pass, - candidate_pass=cdd_test_result.did_pass, - ) - ) + test_diff.scope = TestDiffScope.DID_PASS + test_diff.original_value = original_test_result.did_pass + test_diff.candidate_value = cdd_test_result.did_pass + test_diffs.append(test_diff) break sys.setrecursionlimit(original_recursion_limit) if did_all_timeout: From 6ea2545aeed0a077f1669375444dfdfffd4f7e25 Mon Sep 17 00:00:00 2001 From: ali Date: Fri, 28 Nov 2025 13:39:00 +0200 Subject: [PATCH 09/22] test: try to fix the candidate and see if the diff is empty --- tests/test_codeflash_capture.py | 69 +++++++++++++++++++++++++++++++-- 1 file changed, 65 insertions(+), 4 deletions(-) diff --git a/tests/test_codeflash_capture.py b/tests/test_codeflash_capture.py index 4ab52dd80..bd6518b24 100644 --- a/tests/test_codeflash_capture.py +++ b/tests/test_codeflash_capture.py @@ -1392,10 +1392,7 @@ def risk_adjusted_return(return_val, weight): candidate_helper_code = {} for file_path in file_path_to_helper_class: candidate_helper_code[file_path] = Path(file_path).read_text("utf-8") - file_path_to_helper_classes = { - # Path(helper_path_1): {"HelperClass1"}, - # Path(helper_path_2): {"HelperClass2", "AnotherHelperClass"}, - } + file_path_to_helper_classes = {} instrument_codeflash_capture(fto, file_path_to_helper_classes, tests_root) modified_test_results, coverage_data = func_optimizer.run_and_parse_tests( testing_type=TestingMode.BEHAVIOR, @@ -1413,6 +1410,70 @@ def risk_adjusted_return(return_val, weight): assert not matched + new_fixed_code = """import math +from typing import List, Tuple, Optional + +def calculate_portfolio_metrics( + investments: List[Tuple[str, float, float]], + risk_free_rate: float = 0.02 +) -> dict: + if not investments: + raise ValueError("Investments list cannot be empty") + + # Tolerant weight check (matches original) + total_weight = sum(weight for _, weight, _ in investments) + if abs(total_weight - 1.0) > 1e-10: + raise ValueError("Portfolio weights must sum to 1.0") + + # Same weighted return as original + weighted_return = sum(weight * ret for _, weight, ret in investments) + + # Same volatility formula as original + volatility = math.sqrt(sum((weight * ret) ** 2 for _, weight, ret in investments)) + + # Same Sharpe ratio logic + if volatility == 0: + sharpe_ratio = 0.0 + else: + sharpe_ratio = (weighted_return - risk_free_rate) / volatility + + # Same best/worst logic (based on return only) + best_asset = max(investments, key=lambda x: x[2]) + worst_asset = min(investments, key=lambda x: x[2]) + + return { + "weighted_return": round(weighted_return, 6), + "volatility": round(volatility, 6), + "sharpe_ratio": round(sharpe_ratio, 6), + "best_performing": (best_asset[0], round(best_asset[2], 6)), + "worst_performing": (worst_asset[0], round(worst_asset[2], 6)), + "total_assets": len(investments), + } +""" + with fto_file_path.open("w") as f: + f.write(new_fixed_code) + candidate_fto_code = Path(fto.file_path).read_text("utf-8") + candidate_helper_code = {} + for file_path in file_path_to_helper_class: + candidate_helper_code[file_path] = Path(file_path).read_text("utf-8") + file_path_to_helper_classes = {} + instrument_codeflash_capture(fto, file_path_to_helper_classes, tests_root) + modified_test_results_2, coverage_data = func_optimizer.run_and_parse_tests( + testing_type=TestingMode.BEHAVIOR, + test_env=test_env, + test_files=func_optimizer.test_files, + optimization_iteration=0, + pytest_min_loops=1, + pytest_max_loops=1, + testing_time=0.1, + ) + # Remove instrumentation + FunctionOptimizer.write_code_and_helpers(candidate_fto_code, candidate_helper_code, fto.file_path) + matched, diffs = compare_test_results(test_results, modified_test_results_2) + # now the test should match and no diffs should be found + assert len(diffs) == 0 + assert matched + finally: test_path.unlink(missing_ok=True) fto_file_path.unlink(missing_ok=True) \ No newline at end of file From fe68772bd421659f27ae3b2589394ed594128c33 Mon Sep 17 00:00:00 2001 From: Codeflash Bot Date: Sun, 30 Nov 2025 17:11:36 -0500 Subject: [PATCH 10/22] capture all test discrepancies --- codeflash/verification/equivalence.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/codeflash/verification/equivalence.py b/codeflash/verification/equivalence.py index d39433416..dc0dda952 100644 --- a/codeflash/verification/equivalence.py +++ b/codeflash/verification/equivalence.py @@ -119,7 +119,6 @@ def compare_test_results(original_results: TestResults, candidate_results: TestR candidate_pass=cdd_test_result.did_pass, ) ) - break if original_test_result.test_type in { TestType.EXISTING_UNIT_TEST, @@ -138,7 +137,6 @@ def compare_test_results(original_results: TestResults, candidate_results: TestR candidate_pass=cdd_test_result.did_pass, ) ) - break sys.setrecursionlimit(original_recursion_limit) if did_all_timeout: return False, test_diffs From ed39ec8bb2978fe4ae1200342a7081dc2087b9c2 Mon Sep 17 00:00:00 2001 From: Codeflash Bot Date: Sun, 30 Nov 2025 18:28:10 -0500 Subject: [PATCH 11/22] do the repair in main loop --- codeflash/optimization/function_optimizer.py | 53 ++++++++++++++------ 1 file changed, 37 insertions(+), 16 deletions(-) diff --git a/codeflash/optimization/function_optimizer.py b/codeflash/optimization/function_optimizer.py index 9f3d1cfd0..a13e3058e 100644 --- a/codeflash/optimization/function_optimizer.py +++ b/codeflash/optimization/function_optimizer.py @@ -13,6 +13,7 @@ from typing import TYPE_CHECKING import libcst as cst +import sentry_sdk from rich.console import Group from rich.panel import Panel from rich.syntax import Syntax @@ -583,6 +584,7 @@ def determine_best_candidate( baseline_results=original_code_baseline, original_helper_code=original_helper_code, file_path_to_helper_classes=file_path_to_helper_classes, + code_context=code_context, ) console.rule() if not is_successful(run_results): @@ -672,21 +674,21 @@ def determine_best_candidate( async_throughput=candidate_result.async_throughput, ) valid_optimizations.append(best_optimization) - # queue corresponding refined optimization for best optimization - if not candidate.optimization_id.endswith("refi"): - future_all_refinements.append( - self.refine_optimizations( - valid_optimizations=[best_optimization], - original_code_baseline=original_code_baseline, - code_context=code_context, - trace_id=self.function_trace_id[:-4] + exp_type - if self.experiment_id - else self.function_trace_id, - ai_service_client=ai_service_client, - executor=self.executor, - function_references=function_references, - ) - ) + # # queue corresponding refined optimization for best optimization + # if not candidate.optimization_id.endswith("refi"): + # future_all_refinements.append( + # self.refine_optimizations( + # valid_optimizations=[best_optimization], + # original_code_baseline=original_code_baseline, + # code_context=code_context, + # trace_id=self.function_trace_id[:-4] + exp_type + # if self.experiment_id + # else self.function_trace_id, + # ai_service_client=ai_service_client, + # executor=self.executor, + # function_references=function_references, + # ) + # ) else: # For async functions, prioritize throughput metrics over runtime even for slow candidates is_async = ( @@ -1813,6 +1815,7 @@ def run_optimized_candidate( ) ) console.rule() + # print(type(code_context), type(candidate)) match, diffs = compare_test_results(baseline_results.behavior_test_results, candidate_behavior_results) if match: logger.info("h3|Test results matched ✅") @@ -1823,7 +1826,25 @@ def run_optimized_candidate( # if the test unmatched percentage is greater than 50%, we can't fix it return self.get_results_not_matched_error() - print(f"should try to fix it, diffs: {diffs}") + logger.info("running code repair...") + # not sure if all return types will be convertible to string + diff_per_test_fn = {} + for diff in diffs: + try: + diff_per_test_fn.setdefault(diff.test_src_code, []).append( + f"Expected Value: {diff.original_value!s}\nActual Value: {diff.candidate_value!s}\nError String:{diff.pytest_error}" + ) + except Exception as e: + sentry_sdk.capture_exception(e) + logger.exception(e) + try: + test_issues = "\n".join( + f"{test_fn_def}\n{value}" for test_fn_def, value in diff_per_test_fn.items() + ) + except Exception as e: + sentry_sdk.capture_exception(e) + logger.exception(e) + print(type(diff_per_test_fn), type(test_issues)) # with the parsed test results diff ask the llm to fix the candidate to match the test results of the original code, and run again # self.run_optimized_candidate( # optimization_candidate_index=optimization_candidate_index, From 142da4c748bf0734a5792d5abe5c1ff22bd5f86d Mon Sep 17 00:00:00 2001 From: Codeflash Bot Date: Sun, 30 Nov 2025 19:45:49 -0500 Subject: [PATCH 12/22] todo write backend endpoint --- codeflash/optimization/function_optimizer.py | 41 ++++++++++++++------ 1 file changed, 29 insertions(+), 12 deletions(-) diff --git a/codeflash/optimization/function_optimizer.py b/codeflash/optimization/function_optimizer.py index a13e3058e..2371917b7 100644 --- a/codeflash/optimization/function_optimizer.py +++ b/codeflash/optimization/function_optimizer.py @@ -584,13 +584,34 @@ def determine_best_candidate( baseline_results=original_code_baseline, original_helper_code=original_helper_code, file_path_to_helper_classes=file_path_to_helper_classes, - code_context=code_context, ) console.rule() if not is_successful(run_results): optimized_runtimes[candidate.optimization_id] = None is_correct[candidate.optimization_id] = False speedup_ratios[candidate.optimization_id] = None + fail_value = run_results.value + if ( + fail_value != "Test results did not match the test results of the original code." + and len(future_all_refinements) <= 3 + and not candidate.optimization_id.endswith("cdrp") + ): + # # queue corresponding code repair optimization for best optimization + future_all_refinements.append( + self.code_repair_optimizations( + original_source_code=candidate, + modified_source_code=code_context, + original_code_baseline=original_code_baseline, + test_details="test_details", + code_context=code_context, + trace_id=self.function_trace_id[:-4] + exp_type + if self.experiment_id + else self.function_trace_id, + ai_service_client=ai_service_client, + executor=self.executor, + function_references=function_references, + ) + ) else: candidate_result: OptimizedCandidateResult = run_results.unwrap() best_test_runtime = candidate_result.best_test_runtime @@ -1831,12 +1852,15 @@ def run_optimized_candidate( diff_per_test_fn = {} for diff in diffs: try: - diff_per_test_fn.setdefault(diff.test_src_code, []).append( - f"Expected Value: {diff.original_value!s}\nActual Value: {diff.candidate_value!s}\nError String:{diff.pytest_error}" + diff_per_test_fn[diff.test_src_code] = ( + diff_per_test_fn.setdefault(diff.test_src_code, "") + + f"Expected Value: {diff.original_value!s}\nActual Value: {diff.candidate_value!s}\nError String:{diff.pytest_error}\n" ) + except Exception as e: sentry_sdk.capture_exception(e) logger.exception(e) + return self.get_results_not_matched_error() try: test_issues = "\n".join( f"{test_fn_def}\n{value}" for test_fn_def, value in diff_per_test_fn.items() @@ -1844,15 +1868,8 @@ def run_optimized_candidate( except Exception as e: sentry_sdk.capture_exception(e) logger.exception(e) - print(type(diff_per_test_fn), type(test_issues)) - # with the parsed test results diff ask the llm to fix the candidate to match the test results of the original code, and run again - # self.run_optimized_candidate( - # optimization_candidate_index=optimization_candidate_index, - # baseline_results=baseline_results, - # original_helper_code=original_helper_code, - # file_path_to_helper_classes=file_path_to_helper_classes, - # ) - return self.get_results_not_matched_error() + return self.get_results_not_matched_error() + return Failure(test_issues) logger.info(f"loading|Running performance tests for candidate {optimization_candidate_index}...") From 5a7c3563fba0b93fd74e05916b9bff9a62293259 Mon Sep 17 00:00:00 2001 From: Codeflash Bot Date: Sun, 30 Nov 2025 19:56:33 -0500 Subject: [PATCH 13/22] need to test now --- codeflash/api/aiservice.py | 55 +++++++++++++++++++- codeflash/models/models.py | 9 ++++ codeflash/optimization/function_optimizer.py | 21 ++++++++ 3 files changed, 84 insertions(+), 1 deletion(-) diff --git a/codeflash/api/aiservice.py b/codeflash/api/aiservice.py index 20c478eb4..eac035a13 100644 --- a/codeflash/api/aiservice.py +++ b/codeflash/api/aiservice.py @@ -27,7 +27,7 @@ from codeflash.discovery.functions_to_optimize import FunctionToOptimize from codeflash.models.ExperimentMetadata import ExperimentMetadata - from codeflash.models.models import AIServiceRefinerRequest + from codeflash.models.models import AIServiceCodeRepairRequest, AIServiceRefinerRequest from codeflash.result.explanation import Explanation @@ -294,6 +294,59 @@ def optimize_python_code_refinement(self, request: list[AIServiceRefinerRequest] console.rule() return [] + def optimize_python_code_repair(self, request: list[AIServiceCodeRepairRequest]) -> list[OptimizedCandidate]: + """Optimize the given python code for performance by making a request to the Django endpoint. + + Args: + request: A list of optimization candidate details for refinement + + Returns: + ------- + - List[OptimizationCandidate]: A list of Optimization Candidates. + + """ + payload = [ + { + "optimization_id": opt.optimization_id, + "original_source_code": opt.original_source_code, + "modified_source_code": opt.modified_source_code, + "trace_id": opt.trace_id, + } + for opt in request + ] + # logger.debug(f"Repair {len(request)} optimizations…") + console.rule() + try: + response = self.make_ai_service_request("/code_repair", payload=payload, timeout=120) + except requests.exceptions.RequestException as e: + logger.exception(f"Error generating optimization repair: {e}") + ph("cli-optimize-error-caught", {"error": str(e)}) + return [] + + if response.status_code == 200: + refined_optimizations = response.json()["code_repairs"] + logger.debug(f"Generated {len(refined_optimizations)} candidate refinements.") + console.rule() + + refinements = self._get_valid_candidates(refined_optimizations) + return [ + OptimizedCandidate( + source_code=c.source_code, + explanation=c.explanation, + optimization_id=c.optimization_id[:-4] + "cdrp", + ) + for c in refinements + ] + + try: + error = response.json()["error"] + except Exception: + error = response.text + logger.error(f"Error generating optimized candidates: {response.status_code} - {error}") + ph("cli-optimize-error-response", {"response_status_code": response.status_code, "error": error}) + console.rule() + return [] + def get_new_explanation( # noqa: D417 self, source_code: str, diff --git a/codeflash/models/models.py b/codeflash/models/models.py index 48ecf396a..66d952d01 100644 --- a/codeflash/models/models.py +++ b/codeflash/models/models.py @@ -48,6 +48,15 @@ class AIServiceRefinerRequest: function_references: str | None = None +@dataclass(frozen=True) +class AIServiceCodeRepairRequest: + optimization_id: str + original_source_code: str + modified_source_code: str + test_details: str + trace_id: str + + # If the method spam is in the class Ham, which is at the top level of the module eggs in the package foo, the fully # qualified name of the method is foo.eggs.Ham.spam, its qualified name is Ham.spam, and its name is spam. The full name # of the module is foo.eggs. diff --git a/codeflash/optimization/function_optimizer.py b/codeflash/optimization/function_optimizer.py index 2371917b7..78b6305c9 100644 --- a/codeflash/optimization/function_optimizer.py +++ b/codeflash/optimization/function_optimizer.py @@ -70,6 +70,7 @@ from codeflash.lsp.lsp_message import LspCodeMessage, LspMarkdownMessage, LSPMessageId from codeflash.models.ExperimentMetadata import ExperimentMetadata from codeflash.models.models import ( + AIServiceCodeRepairRequest, BestOptimization, CodeOptimizationContext, GeneratedTests, @@ -862,6 +863,26 @@ def refine_optimizations( ] return executor.submit(ai_service_client.optimize_python_code_refinement, request=request) + def code_repair_optimizations( + self, + original_source_code: str, + modified_source_code: str, + test_details: str, + trace_id: str, + ai_service_client: AiServiceClient, + executor: concurrent.futures.ThreadPoolExecutor, + ) -> concurrent.futures.Future: + request = [ + AIServiceCodeRepairRequest( + optimization_id="", + original_source_code=original_source_code, + modified_source_code=modified_source_code, + test_details=test_details, + trace_id=trace_id, + ) + ] + return executor.submit(ai_service_client.optimize_python_code_repair, request=request) + def log_successful_optimization( self, explanation: Explanation, generated_tests: GeneratedTestsList, exp_type: str ) -> None: From 5ed5dfcfb542e8e82331f0794e5ea9caef39ab13 Mon Sep 17 00:00:00 2001 From: Codeflash Bot Date: Sun, 30 Nov 2025 22:18:09 -0500 Subject: [PATCH 14/22] works, figure out logging --- codeflash/api/aiservice.py | 3 ++- codeflash/optimization/function_optimizer.py | 15 +++++++-------- 2 files changed, 9 insertions(+), 9 deletions(-) diff --git a/codeflash/api/aiservice.py b/codeflash/api/aiservice.py index eac035a13..afb529534 100644 --- a/codeflash/api/aiservice.py +++ b/codeflash/api/aiservice.py @@ -310,6 +310,7 @@ def optimize_python_code_repair(self, request: list[AIServiceCodeRepairRequest]) "optimization_id": opt.optimization_id, "original_source_code": opt.original_source_code, "modified_source_code": opt.modified_source_code, + "test_details": opt.test_details, "trace_id": opt.trace_id, } for opt in request @@ -325,7 +326,7 @@ def optimize_python_code_repair(self, request: list[AIServiceCodeRepairRequest]) if response.status_code == 200: refined_optimizations = response.json()["code_repairs"] - logger.debug(f"Generated {len(refined_optimizations)} candidate refinements.") + # logger.debug(f"Generated {len(refined_optimizations)} candidate refinements.") console.rule() refinements = self._get_valid_candidates(refined_optimizations) diff --git a/codeflash/optimization/function_optimizer.py b/codeflash/optimization/function_optimizer.py index 78b6305c9..51e788f96 100644 --- a/codeflash/optimization/function_optimizer.py +++ b/codeflash/optimization/function_optimizer.py @@ -593,24 +593,22 @@ def determine_best_candidate( speedup_ratios[candidate.optimization_id] = None fail_value = run_results.value if ( - fail_value != "Test results did not match the test results of the original code." + fail_value.strip() != "Test results did not match the test results of the original code." and len(future_all_refinements) <= 3 and not candidate.optimization_id.endswith("cdrp") ): # # queue corresponding code repair optimization for best optimization future_all_refinements.append( self.code_repair_optimizations( - original_source_code=candidate, - modified_source_code=code_context, - original_code_baseline=original_code_baseline, - test_details="test_details", - code_context=code_context, + original_source_code=code_context.read_writable_code.markdown, + modified_source_code=candidate.source_code.markdown, + test_details=fail_value, trace_id=self.function_trace_id[:-4] + exp_type if self.experiment_id else self.function_trace_id, ai_service_client=ai_service_client, executor=self.executor, - function_references=function_references, + optimization_id=candidate.optimization_id, ) ) else: @@ -869,12 +867,13 @@ def code_repair_optimizations( modified_source_code: str, test_details: str, trace_id: str, + optimization_id: str, ai_service_client: AiServiceClient, executor: concurrent.futures.ThreadPoolExecutor, ) -> concurrent.futures.Future: request = [ AIServiceCodeRepairRequest( - optimization_id="", + optimization_id=optimization_id, original_source_code=original_source_code, modified_source_code=modified_source_code, test_details=test_details, From fe33c8244036c5b4c119a31330977c19f3e1ec71 Mon Sep 17 00:00:00 2001 From: Codeflash Bot Date: Sun, 30 Nov 2025 23:53:36 -0500 Subject: [PATCH 15/22] local db logging --- codeflash/optimization/function_optimizer.py | 108 ++++++++++++++++++- 1 file changed, 107 insertions(+), 1 deletion(-) diff --git a/codeflash/optimization/function_optimizer.py b/codeflash/optimization/function_optimizer.py index 51e788f96..b100dd144 100644 --- a/codeflash/optimization/function_optimizer.py +++ b/codeflash/optimization/function_optimizer.py @@ -5,6 +5,7 @@ import os import queue import random +import sqlite3 import subprocess import time import uuid @@ -119,6 +120,61 @@ from codeflash.verification.verification_utils import TestConfig +def log_code_repair_to_db( + code_repair_log_db: Path, + optimization_id: str, + trace_id: str | None = None, + passed: str | None = None, + faster: str | None = None, +) -> None: + """Log code repair data to SQLite database. + + Uses upsert pattern to allow incremental logging with different columns at different places. + Only non-None values will be updated; existing values are preserved. + """ + try: + conn = sqlite3.connect(code_repair_log_db) + cursor = conn.cursor() + + # Build dynamic upsert query based on provided columns + columns = ["optimization_id"] + values = [optimization_id] + update_parts = ["updated_at = CURRENT_TIMESTAMP"] + + if trace_id is not None: + columns.append("trace_id") + values.append(trace_id) + update_parts.append("trace_id = excluded.trace_id") + + if passed is not None: + columns.append("passed") + values.append(passed) + update_parts.append("passed = excluded.passed") + + if faster is not None: + columns.append("faster") + values.append(faster) + update_parts.append("faster = excluded.faster") + + placeholders = ", ".join(["?"] * len(values)) + columns_str = ", ".join(columns) + update_str = ", ".join(update_parts) + + cursor.execute( + f""" + INSERT INTO code_repair_logs_cf ({columns_str}) + VALUES ({placeholders}) + ON CONFLICT(optimization_id) DO UPDATE SET {update_str} + """, # noqa: S608 + values, + ) + conn.commit() + conn.close() + except Exception as e: + sentry_sdk.capture_exception(e) + logger.exception(e) + + class CandidateProcessor: """Handles candidate processing using a queue-based approach.""" @@ -249,6 +305,8 @@ def __init__( max_workers=n_tests + 3 if self.experiment_id is None else n_tests + 4 ) self.optimization_review = "" + # SQLite database setup for logging + self.code_repair_log_db = Path(__file__).parent / "code_repair_log_cf.db" def can_be_optimized(self) -> Result[tuple[bool, CodeOptimizationContext, dict[Path, str]], str]: should_run_experiment = self.experiment_id is not None @@ -389,7 +447,19 @@ def optimize_function(self) -> Result[BestOptimization, str]: initialization_result = self.can_be_optimized() if not is_successful(initialization_result): return Failure(initialization_result.failure()) - + conn = sqlite3.connect(self.code_repair_log_db) + cursor = conn.cursor() + cursor.execute(""" + CREATE TABLE IF NOT EXISTS code_repair_logs ( + optimization_id TEXT PRIMARY KEY, + trace_id TEXT, + passed TEXT, + faster TEXT, + created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, + updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP + """) + conn.commit() + conn.close() should_run_experiment, code_context, original_helper_code = initialization_result.unwrap() code_print( @@ -540,6 +610,14 @@ def determine_best_candidate( logger.warning( "force_lsp|No functions were replaced in the optimized code. Skipping optimization candidate." ) + if candidate.optimization_id.endswith("cdrp"): + log_code_repair_to_db( + code_repair_log_db=self.code_repair_log_db, + trace_id=self.function_trace_id[:-4] + exp_type, + optimization_id=candidate.optimization_id, + passed="no", + faster="no", # this also may or may not pass + ) console.rule() continue except (ValueError, SyntaxError, cst.ParserSyntaxError, AttributeError) as e: @@ -547,6 +625,14 @@ def determine_best_candidate( self.write_code_and_helpers( self.function_to_optimize_source_code, original_helper_code, self.function_to_optimize.file_path ) + if candidate.optimization_id.endswith("cdrp"): + log_code_repair_to_db( + code_repair_log_db=self.code_repair_log_db, + trace_id=self.function_trace_id[:-4] + exp_type, + optimization_id=candidate.optimization_id, + passed="no", + faster="no", # this also may or may not pass + ) continue # check if this code has been evaluated before by checking the ast normalized code string normalized_code = normalize_code(candidate.source_code.flat.strip()) @@ -574,6 +660,16 @@ def determine_best_candidate( ): # new candidate has a shorter diff than the previously encountered one ast_code_to_id[normalized_code]["shorter_source_code"] = candidate.source_code ast_code_to_id[normalized_code]["diff_len"] = new_diff_len + if candidate.optimization_id.endswith("cdrp"): + log_code_repair_to_db( + code_repair_log_db=self.code_repair_log_db, + trace_id=self.function_trace_id[:-4] + exp_type, + optimization_id=candidate.optimization_id, + passed="yes" if is_correct[candidate.optimization_id] else "no", + faster="yes" + if speedup_ratios[candidate.optimization_id] > 0 + else "no", # this also may or may not pass + ) continue ast_code_to_id[normalized_code] = { "optimization_id": candidate.optimization_id, @@ -743,6 +839,16 @@ def determine_best_candidate( if self.args.benchmark and benchmark_tree: console.print(benchmark_tree) console.rule() + if candidate.optimization_id.endswith("cdrp"): + log_code_repair_to_db( + code_repair_log_db=self.code_repair_log_db, + trace_id=self.function_trace_id[:-4] + exp_type, + optimization_id=candidate.optimization_id, + passed="yes" if is_correct[candidate.optimization_id] else "no", + faster="yes" + if speedup_ratios[candidate.optimization_id] > 0 + else "no", # this also may or may not pass + ) except KeyboardInterrupt as e: logger.exception(f"Optimization interrupted: {e}") raise From 83814bee599af0696b20790b0617b2bb0acab0d0 Mon Sep 17 00:00:00 2001 From: Codeflash Bot Date: Mon, 1 Dec 2025 00:09:49 -0500 Subject: [PATCH 16/22] ready to run experiments --- codeflash/optimization/function_optimizer.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/codeflash/optimization/function_optimizer.py b/codeflash/optimization/function_optimizer.py index b100dd144..4c6322fdf 100644 --- a/codeflash/optimization/function_optimizer.py +++ b/codeflash/optimization/function_optimizer.py @@ -306,7 +306,7 @@ def __init__( ) self.optimization_review = "" # SQLite database setup for logging - self.code_repair_log_db = Path(__file__).parent / "code_repair_log_cf.db" + self.code_repair_log_db = Path(__file__).parent / "code_repair_logs_cf.db" def can_be_optimized(self) -> Result[tuple[bool, CodeOptimizationContext, dict[Path, str]], str]: should_run_experiment = self.experiment_id is not None @@ -450,13 +450,14 @@ def optimize_function(self) -> Result[BestOptimization, str]: conn = sqlite3.connect(self.code_repair_log_db) cursor = conn.cursor() cursor.execute(""" - CREATE TABLE IF NOT EXISTS code_repair_logs ( + CREATE TABLE IF NOT EXISTS code_repair_logs_cf ( optimization_id TEXT PRIMARY KEY, trace_id TEXT, passed TEXT, faster TEXT, created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP + ) """) conn.commit() conn.close() @@ -1980,7 +1981,7 @@ def run_optimized_candidate( try: diff_per_test_fn[diff.test_src_code] = ( diff_per_test_fn.setdefault(diff.test_src_code, "") - + f"Expected Value: {diff.original_value!s}\nActual Value: {diff.candidate_value!s}\nError String:{diff.pytest_error}\n" + + f"Expected Value: {diff.original_value!s}\nActual Value: {diff.candidate_value!s}\nError String:{diff.candidate_pytest_error}\n" ) except Exception as e: From 0325444ce4626aea31ccaafe4650ddd1c9c3772e Mon Sep 17 00:00:00 2001 From: Codeflash Bot Date: Mon, 1 Dec 2025 00:20:09 -0500 Subject: [PATCH 17/22] logging fix --- codeflash/optimization/function_optimizer.py | 18 ++++++++++++------ 1 file changed, 12 insertions(+), 6 deletions(-) diff --git a/codeflash/optimization/function_optimizer.py b/codeflash/optimization/function_optimizer.py index 4c6322fdf..2129da430 100644 --- a/codeflash/optimization/function_optimizer.py +++ b/codeflash/optimization/function_optimizer.py @@ -617,7 +617,7 @@ def determine_best_candidate( trace_id=self.function_trace_id[:-4] + exp_type, optimization_id=candidate.optimization_id, passed="no", - faster="no", # this also may or may not pass + faster="no", ) console.rule() continue @@ -632,7 +632,7 @@ def determine_best_candidate( trace_id=self.function_trace_id[:-4] + exp_type, optimization_id=candidate.optimization_id, passed="no", - faster="no", # this also may or may not pass + faster="no", ) continue # check if this code has been evaluated before by checking the ast normalized code string @@ -668,8 +668,11 @@ def determine_best_candidate( optimization_id=candidate.optimization_id, passed="yes" if is_correct[candidate.optimization_id] else "no", faster="yes" - if speedup_ratios[candidate.optimization_id] > 0 - else "no", # this also may or may not pass + if ( + speedup_ratios[candidate.optimization_id] is not None + and speedup_ratios[candidate.optimization_id] > 0 + ) + else "no", ) continue ast_code_to_id[normalized_code] = { @@ -847,8 +850,11 @@ def determine_best_candidate( optimization_id=candidate.optimization_id, passed="yes" if is_correct[candidate.optimization_id] else "no", faster="yes" - if speedup_ratios[candidate.optimization_id] > 0 - else "no", # this also may or may not pass + if ( + speedup_ratios[candidate.optimization_id] is not None + and speedup_ratios[candidate.optimization_id] > 0 + ) + else "no", ) except KeyboardInterrupt as e: logger.exception(f"Optimization interrupted: {e}") From 9f7ed9030c16f3c4e3bfcfb57e946d6b82419634 Mon Sep 17 00:00:00 2001 From: ali Date: Mon, 1 Dec 2025 16:48:23 +0200 Subject: [PATCH 18/22] handle test class methods for the test diff --- codeflash/models/models.py | 5 +++++ codeflash/verification/equivalence.py | 6 ++++-- tests/test_codeflash_capture.py | 1 - 3 files changed, 9 insertions(+), 3 deletions(-) diff --git a/codeflash/models/models.py b/codeflash/models/models.py index 66d952d01..39388630a 100644 --- a/codeflash/models/models.py +++ b/codeflash/models/models.py @@ -515,6 +515,11 @@ def id(self) -> str: f"{self.function_getting_tested}:{self.iteration_id}" ) + # TestSuiteClass.test_function_name + def test_fn_qualified_name(self) -> str: + class_prefix = f"{self.test_class_name}." if self.test_class_name else "" + return f"{class_prefix}{self.test_function_name}" + def find_func_in_class(self, class_node: cst.ClassDef, func_name: str) -> Optional[cst.FunctionDef]: for stmt in class_node.body.body: if isinstance(stmt, cst.FunctionDef) and stmt.name.value == func_name: diff --git a/codeflash/verification/equivalence.py b/codeflash/verification/equivalence.py index 642da9194..6eff438e4 100644 --- a/codeflash/verification/equivalence.py +++ b/codeflash/verification/equivalence.py @@ -53,12 +53,14 @@ def compare_test_results(original_results: TestResults, candidate_results: TestR candidate_test_failures = candidate_results.test_failures original_test_failures = original_results.test_failures cdd_pytest_error = ( - candidate_test_failures.get(original_test_result.id.test_function_name, "") + candidate_test_failures.get(original_test_result.id.test_fn_qualified_name(), "") if candidate_test_failures else "" ) original_pytest_error = ( - original_test_failures.get(original_test_result.id.test_function_name, "") if original_test_failures else "" + original_test_failures.get(original_test_result.id.test_fn_qualified_name(), "") + if original_test_failures + else "" ) if cdd_test_result is not None and original_test_result is None: diff --git a/tests/test_codeflash_capture.py b/tests/test_codeflash_capture.py index bd6518b24..79133bc15 100644 --- a/tests/test_codeflash_capture.py +++ b/tests/test_codeflash_capture.py @@ -1406,7 +1406,6 @@ def risk_adjusted_return(return_val, weight): # Remove instrumentation FunctionOptimizer.write_code_and_helpers(candidate_fto_code, candidate_helper_code, fto.file_path) matched, diffs = compare_test_results(test_results, modified_test_results) - print(diffs) assert not matched From 6060ffbfe82d46864d2d9753fb62dfcd98e291ad Mon Sep 17 00:00:00 2001 From: mohammed ahmed <64513301+mohammedahmed18@users.noreply.github.com> Date: Tue, 2 Dec 2025 13:48:06 +0200 Subject: [PATCH 19/22] codeflash suggestion Co-authored-by: codeflash-ai[bot] <148906541+codeflash-ai[bot]@users.noreply.github.com> --- codeflash/models/models.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/codeflash/models/models.py b/codeflash/models/models.py index 39388630a..e4aa623d8 100644 --- a/codeflash/models/models.py +++ b/codeflash/models/models.py @@ -517,8 +517,12 @@ def id(self) -> str: # TestSuiteClass.test_function_name def test_fn_qualified_name(self) -> str: - class_prefix = f"{self.test_class_name}." if self.test_class_name else "" - return f"{class_prefix}{self.test_function_name}" + # Use f-string with inline conditional to reduce string concatenation operations + return ( + f"{self.test_class_name}.{self.test_function_name}" + if self.test_class_name + else str(self.test_function_name) + ) def find_func_in_class(self, class_node: cst.ClassDef, func_name: str) -> Optional[cst.FunctionDef]: for stmt in class_node.body.body: From 1120d6416c378060aaa2077e28bc112925f17df0 Mon Sep 17 00:00:00 2001 From: ali Date: Tue, 2 Dec 2025 13:59:08 +0200 Subject: [PATCH 20/22] safer parsing --- codeflash/verification/parse_test_output.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/codeflash/verification/parse_test_output.py b/codeflash/verification/parse_test_output.py index 873f9f341..25fcf4e63 100644 --- a/codeflash/verification/parse_test_output.py +++ b/codeflash/verification/parse_test_output.py @@ -517,16 +517,18 @@ def parse_test_failures_from_stdout(test_results: TestResults, stdout: str) -> T start_line = end_line = None for i, line in enumerate(stdout_lines): - if start_line is None and "FAILURES" in line: + stripped_line = line.strip() + if start_line is None and stripped_line[0] == "=" and "FAILURES" in stripped_line: start_line = i - elif start_line is not None and end_line is None and "short test summary info" in line: + # exclude last summary line + elif start_line is not None and end_line is None and "short test summary info" in stripped_line: end_line = i break if start_line is None or end_line is None: return test_results - complete_failure_output_lines = stdout_lines[start_line:end_line] # exclude last summary line + complete_failure_output_lines = stdout_lines[start_line:end_line] test_case_to_failure: dict[str, str] = {} From c2e037aa4321f9b6d808a1880acb74be435767ee Mon Sep 17 00:00:00 2001 From: ali Date: Tue, 2 Dec 2025 14:09:00 +0200 Subject: [PATCH 21/22] better parsing for pytest stdout --- codeflash/verification/parse_test_output.py | 68 +++++++++++++-------- 1 file changed, 41 insertions(+), 27 deletions(-) diff --git a/codeflash/verification/parse_test_output.py b/codeflash/verification/parse_test_output.py index 25fcf4e63..f5cdad9d1 100644 --- a/codeflash/verification/parse_test_output.py +++ b/codeflash/verification/parse_test_output.py @@ -512,44 +512,58 @@ def merge_test_results( return merged_test_results +FAILURES_HEADER_RE = re.compile(r"=+ FAILURES =+") +TEST_HEADER_RE = re.compile(r"_{3,}\s*(.*?)\s*_{3,}$") + + def parse_test_failures_from_stdout(test_results: TestResults, stdout: str) -> TestResults: - stdout_lines = stdout.splitlines() - start_line = end_line = None - - for i, line in enumerate(stdout_lines): - stripped_line = line.strip() - if start_line is None and stripped_line[0] == "=" and "FAILURES" in stripped_line: - start_line = i - # exclude last summary line - elif start_line is not None and end_line is None and "short test summary info" in stripped_line: - end_line = i + """Extract individual pytest test failures from stdout grouped by test case qualified name, and add them to the test results.""" + lines = stdout.splitlines() + start = end = None + + for i, line in enumerate(lines): + if FAILURES_HEADER_RE.search(line.strip()): + start = i break - if start_line is None or end_line is None: + if start is None: return test_results - complete_failure_output_lines = stdout_lines[start_line:end_line] + for j in range(start + 1, len(lines)): + stripped = lines[j].strip() + if "short test summary info" in stripped: + end = j + break + # any new === section === block + if stripped.startswith("=") and stripped.count("=") > 3: + end = j + break + + # If no clear "end", just grap the rest of the string + if end is None: + end = len(lines) - test_case_to_failure: dict[str, str] = {} + failure_block = lines[start:end] - current_test_case: str | None = None - current_failure_lines: list[str] = [] + failures: dict[str, str] = {} + current_name = None + current_lines: list[str] = [] - underline_prefix = "_______" + for line in failure_block: + m = TEST_HEADER_RE.match(line.strip()) + if m: + if current_name is not None: + failures[current_name] = "".join(current_lines) - for line in complete_failure_output_lines: - if line and line[0] == "_" and line.startswith(underline_prefix): - if current_test_case: - test_case_to_failure[current_test_case] = "".join(current_failure_lines) - current_test_case = line.strip("_ ").strip() - current_failure_lines.clear() - elif current_test_case: - current_failure_lines.append(line + "\n") + current_name = m.group(1) + current_lines = [] + elif current_name: + current_lines.append(line + "\n") - if current_test_case: - test_case_to_failure[current_test_case] = "".join(current_failure_lines) + if current_name: + failures[current_name] = "".join(current_lines) - test_results.test_failures = test_case_to_failure + test_results.test_failures = failures return test_results From bd1ebf4b76dc149919b2e35877f4adb1274c7287 Mon Sep 17 00:00:00 2001 From: Codeflash Bot Date: Tue, 2 Dec 2025 22:52:34 -0500 Subject: [PATCH 22/22] temp logging --- codeflash/optimization/function_optimizer.py | 84 +++++--------------- 1 file changed, 22 insertions(+), 62 deletions(-) diff --git a/codeflash/optimization/function_optimizer.py b/codeflash/optimization/function_optimizer.py index 2129da430..560c54fd1 100644 --- a/codeflash/optimization/function_optimizer.py +++ b/codeflash/optimization/function_optimizer.py @@ -121,58 +121,32 @@ def log_code_repair_to_db( - code_repair_log_db: Path, - optimization_id: str, - trace_id: str | None = None, - passed: str | None = None, - faster: str | None = None, + code_repair_log_db: Path, optimization_id: str, trace_id: str, passed: str, faster: str ) -> None: - """Log code repair data to SQLite database. - - Uses upsert pattern to allow incremental logging with different columns at different places. - Only non-None values will be updated; existing values are preserved. - """ + """Log code repair data to SQLite database.""" try: - conn = sqlite3.connect(code_repair_log_db) - cursor = conn.cursor() - - # Build dynamic upsert query based on provided columns - columns = ["optimization_id"] - values = [optimization_id] - update_parts = ["updated_at = CURRENT_TIMESTAMP"] - - if trace_id is not None: - columns.append("trace_id") - values.append(trace_id) - update_parts.append("trace_id = excluded.trace_id") - - if passed is not None: - columns.append("passed") - values.append(passed) - update_parts.append("passed = excluded.passed") - - if faster is not None: - columns.append("faster") - values.append(faster) - update_parts.append("faster = excluded.faster") - - placeholders = ", ".join(["?"] * len(values)) - columns_str = ", ".join(columns) - update_str = ", ".join(update_parts) - - cursor.execute( - f""" - INSERT INTO code_repair_logs_cf ({columns_str}) - VALUES ({placeholders}) - ON CONFLICT(optimization_id) DO UPDATE SET {update_str} - """, # noqa: S608 - values, - ) - conn.commit() - conn.close() + with sqlite3.connect(code_repair_log_db) as conn: + cursor = conn.cursor() + cursor.execute(""" + CREATE TABLE IF NOT EXISTS code_repair_logs_cf ( + optimization_id TEXT PRIMARY KEY, + trace_id TEXT, + passed TEXT, + faster TEXT, + created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP + ) + """) + cursor.execute( + """ + INSERT INTO code_repair_logs_cf (optimization_id, trace_id, passed, faster) + VALUES (?, ?, ?, ?) + """, + (optimization_id, trace_id, passed, faster), + ) + conn.commit() except Exception as e: sentry_sdk.capture_exception(e) - logger.exception(e) + logger.exception("Error logging code repair to db") class CandidateProcessor: @@ -447,20 +421,6 @@ def optimize_function(self) -> Result[BestOptimization, str]: initialization_result = self.can_be_optimized() if not is_successful(initialization_result): return Failure(initialization_result.failure()) - conn = sqlite3.connect(self.code_repair_log_db) - cursor = conn.cursor() - cursor.execute(""" - CREATE TABLE IF NOT EXISTS code_repair_logs_cf ( - optimization_id TEXT PRIMARY KEY, - trace_id TEXT, - passed TEXT, - faster TEXT, - created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, - updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP - ) - """) - conn.commit() - conn.close() should_run_experiment, code_context, original_helper_code = initialization_result.unwrap() code_print(