diff --git a/codeflash/api/aiservice.py b/codeflash/api/aiservice.py index 20c478eb4..8b59fab9a 100644 --- a/codeflash/api/aiservice.py +++ b/codeflash/api/aiservice.py @@ -27,7 +27,7 @@ from codeflash.discovery.functions_to_optimize import FunctionToOptimize from codeflash.models.ExperimentMetadata import ExperimentMetadata - from codeflash.models.models import AIServiceRefinerRequest + from codeflash.models.models import AIServiceCodeRepairRequest, AIServiceRefinerRequest from codeflash.result.explanation import Explanation @@ -294,6 +294,55 @@ def optimize_python_code_refinement(self, request: list[AIServiceRefinerRequest] console.rule() return [] + def optimize_python_code_repair(self, request: AIServiceCodeRepairRequest) -> OptimizedCandidate | None: + """Optimize the given python code for performance by making a request to the Django endpoint. + + Args: + request: optimization candidate details for refinement + + Returns: + ------- + - OptimizationCandidate: new fixed candidate. + + """ + console.rule() + try: + payload = { + "optimization_id": request.optimization_id, + "original_source_code": request.original_source_code, + "modified_source_code": request.modified_source_code, + "trace_id": request.trace_id, + "test_diffs": request.test_diffs, + } + response = self.make_ai_service_request("/code_repair", payload=payload, timeout=120) + except (requests.exceptions.RequestException, TypeError) as e: + logger.exception(f"Error generating optimization repair: {e}") + ph("cli-optimize-error-caught", {"error": str(e)}) + return None + + if response.status_code == 200: + fixed_optimization = response.json() + console.rule() + + if not self._get_valid_candidates([fixed_optimization]): + logger.error("Code repair failed to generate a valid candidate.") + return None + + return OptimizedCandidate( + source_code=fixed_optimization["source_code"], + explanation=fixed_optimization["explanation"], + optimization_id=fixed_optimization["optimization_id"], + ) + + try: + error = response.json()["error"] + except Exception: + error = response.text + logger.error(f"Error generating optimized candidates: {response.status_code} - {error}") + ph("cli-optimize-error-response", {"response_status_code": response.status_code, "error": error}) + console.rule() + return None + def get_new_explanation( # noqa: D417 self, source_code: str, diff --git a/codeflash/models/models.py b/codeflash/models/models.py index 744f76087..52f8ba4c8 100644 --- a/codeflash/models/models.py +++ b/codeflash/models/models.py @@ -3,6 +3,7 @@ from collections import Counter, defaultdict from typing import TYPE_CHECKING +import libcst as cst from rich.tree import Tree from codeflash.cli_cmds.console import DEBUG_MODE, lsp_log @@ -47,6 +48,34 @@ class AIServiceRefinerRequest: function_references: str | None = None +class TestDiffScope(str, Enum): + RETURN_VALUE = "return_value" + STDOUT = "stdout" + DID_PASS = "did_pass" # noqa: S105 + + +@dataclass +class TestDiff: + scope: TestDiffScope + original_pass: bool + candidate_pass: bool + + original_value: str | None = None + candidate_value: str | None = None + test_src_code: Optional[str] = None + candidate_pytest_error: Optional[str] = None + original_pytest_error: Optional[str] = None + + +@dataclass(frozen=True) +class AIServiceCodeRepairRequest: + optimization_id: str + original_source_code: str + modified_source_code: str + trace_id: str + test_diffs: list[TestDiff] + + # If the method spam is in the class Ham, which is at the top level of the module eggs in the package foo, the fully # qualified name of the method is foo.eggs.Ham.spam, its qualified name is Ham.spam, and its name is spam. The full name # of the module is foo.eggs. @@ -505,6 +534,42 @@ def id(self) -> str: f"{self.function_getting_tested}:{self.iteration_id}" ) + # TestSuiteClass.test_function_name + def test_fn_qualified_name(self) -> str: + # Use f-string with inline conditional to reduce string concatenation operations + return ( + f"{self.test_class_name}.{self.test_function_name}" + if self.test_class_name + else str(self.test_function_name) + ) + + def find_func_in_class(self, class_node: cst.ClassDef, func_name: str) -> Optional[cst.FunctionDef]: + for stmt in class_node.body.body: + if isinstance(stmt, cst.FunctionDef) and stmt.name.value == func_name: + return stmt + return None + + def get_src_code(self, test_path: Path) -> Optional[str]: + if not test_path.exists(): + return None + test_src = test_path.read_text(encoding="utf-8") + module_node = cst.parse_module(test_src) + + if self.test_class_name: + for stmt in module_node.body: + if isinstance(stmt, cst.ClassDef) and stmt.name.value == self.test_class_name: + func_node = self.find_func_in_class(stmt, self.test_function_name) + if func_node: + return module_node.code_for_node(func_node).strip() + # class not found + return None + + # Otherwise, look for a top level function + for stmt in module_node.body: + if isinstance(stmt, cst.FunctionDef) and stmt.name.value == self.test_function_name: + return module_node.code_for_node(stmt).strip() + return None + @staticmethod def from_str_id(string_id: str, iteration_id: str | None = None) -> InvocationId: components = string_id.split(":") @@ -549,7 +614,10 @@ class TestResults(BaseModel): # noqa: PLW1641 # also we don't support deletion of test results elements - caution is advised test_results: list[FunctionTestInvocation] = [] test_result_idx: dict[str, int] = {} + perf_stdout: Optional[str] = None + # mapping between test function name and stdout failure message + test_failures: Optional[dict[str, str]] = None def add(self, function_test_invocation: FunctionTestInvocation) -> None: unique_id = function_test_invocation.unique_invocation_loop_id diff --git a/codeflash/optimization/function_optimizer.py b/codeflash/optimization/function_optimizer.py index 2eef51f0f..11139bc53 100644 --- a/codeflash/optimization/function_optimizer.py +++ b/codeflash/optimization/function_optimizer.py @@ -69,6 +69,7 @@ from codeflash.lsp.lsp_message import LspCodeMessage, LspMarkdownMessage, LSPMessageId from codeflash.models.ExperimentMetadata import ExperimentMetadata from codeflash.models.models import ( + AIServiceCodeRepairRequest, BestOptimization, CodeOptimizationContext, GeneratedTests, @@ -113,6 +114,7 @@ CoverageData, FunctionCalledInTest, FunctionSource, + TestDiff, ) from codeflash.verification.verification_utils import TestConfig @@ -247,6 +249,7 @@ def __init__( max_workers=n_tests + 3 if self.experiment_id is None else n_tests + 4 ) self.optimization_review = "" + self.ast_code_to_id = {} def can_be_optimized(self) -> Result[tuple[bool, CodeOptimizationContext, dict[Path, str]], str]: should_run_experiment = self.experiment_id is not None @@ -387,7 +390,6 @@ def optimize_function(self) -> Result[BestOptimization, str]: initialization_result = self.can_be_optimized() if not is_successful(initialization_result): return Failure(initialization_result.failure()) - should_run_experiment, code_context, original_helper_code = initialization_result.unwrap() code_print( @@ -459,6 +461,48 @@ def optimize_function(self) -> Result[BestOptimization, str]: return Failure(f"No best optimizations found for function {self.function_to_optimize.qualified_name}") return Success(best_optimization) + def reset_optimization_metrics_for_candidate( + self, opt_id: str, speedup_ratios: dict, is_correct: dict, optimized_runtimes: dict + ) -> None: + speedup_ratios[opt_id] = None + is_correct[opt_id] = False + optimized_runtimes[opt_id] = None + + def was_candidate_tested_before(self, normalized_code: str) -> bool: + # check if this code has been evaluated before by checking the ast normalized code string + return normalized_code in self.ast_code_to_id + + def update_results_for_duplicate_candidate( + self, + candidate: OptimizedCandidate, + code_context: CodeOptimizationContext, + normalized_code: str, + speedup_ratios: dict, + is_correct: dict, + optimized_runtimes: dict, + optimized_line_profiler_results: dict, + optimizations_post: dict, + ) -> None: + logger.info("Current candidate has been encountered before in testing, Skipping optimization candidate.") + past_opt_id = self.ast_code_to_id[normalized_code]["optimization_id"] + # update speedup ratio, is_correct, optimizations_post, optimized_line_profiler_results, optimized_runtimes + speedup_ratios[candidate.optimization_id] = speedup_ratios[past_opt_id] + is_correct[candidate.optimization_id] = is_correct[past_opt_id] + optimized_runtimes[candidate.optimization_id] = optimized_runtimes[past_opt_id] + # line profiler results only available for successful runs + if past_opt_id in optimized_line_profiler_results: + optimized_line_profiler_results[candidate.optimization_id] = optimized_line_profiler_results[past_opt_id] + optimizations_post[candidate.optimization_id] = self.ast_code_to_id[normalized_code][ + "shorter_source_code" + ].markdown + optimizations_post[past_opt_id] = self.ast_code_to_id[normalized_code]["shorter_source_code"].markdown + new_diff_len = diff_length(candidate.source_code.flat, code_context.read_writable_code.flat) + if ( + new_diff_len < self.ast_code_to_id[normalized_code]["diff_len"] + ): # new candidate has a shorter diff than the previously encountered one + self.ast_code_to_id[normalized_code]["shorter_source_code"] = candidate.source_code + self.ast_code_to_id[normalized_code]["diff_len"] = new_diff_len + def determine_best_candidate( self, *, @@ -485,7 +529,7 @@ def determine_best_candidate( console.rule() future_all_refinements: list[concurrent.futures.Future] = [] - ast_code_to_id = {} + self.ast_code_to_id.clear() valid_optimizations = [] optimizations_post = {} # we need to overwrite some opt candidates' code strings as they are no longer evaluated, instead their shorter/longer versions might be evaluated @@ -548,47 +592,44 @@ def determine_best_candidate( continue # check if this code has been evaluated before by checking the ast normalized code string normalized_code = normalize_code(candidate.source_code.flat.strip()) - if normalized_code in ast_code_to_id: - logger.info( - "Current candidate has been encountered before in testing, Skipping optimization candidate." + if self.was_candidate_tested_before(normalized_code): + self.update_results_for_duplicate_candidate( + candidate=candidate, + code_context=code_context, + normalized_code=normalized_code, + speedup_ratios=speedup_ratios, + is_correct=is_correct, + optimized_runtimes=optimized_runtimes, + optimized_line_profiler_results=optimized_line_profiler_results, + optimizations_post=optimizations_post, ) - past_opt_id = ast_code_to_id[normalized_code]["optimization_id"] - # update speedup ratio, is_correct, optimizations_post, optimized_line_profiler_results, optimized_runtimes - speedup_ratios[candidate.optimization_id] = speedup_ratios[past_opt_id] - is_correct[candidate.optimization_id] = is_correct[past_opt_id] - optimized_runtimes[candidate.optimization_id] = optimized_runtimes[past_opt_id] - # line profiler results only available for successful runs - if past_opt_id in optimized_line_profiler_results: - optimized_line_profiler_results[candidate.optimization_id] = optimized_line_profiler_results[ - past_opt_id - ] - optimizations_post[candidate.optimization_id] = ast_code_to_id[normalized_code][ - "shorter_source_code" - ].markdown - optimizations_post[past_opt_id] = ast_code_to_id[normalized_code]["shorter_source_code"].markdown - new_diff_len = diff_length(candidate.source_code.flat, code_context.read_writable_code.flat) - if ( - new_diff_len < ast_code_to_id[normalized_code]["diff_len"] - ): # new candidate has a shorter diff than the previously encountered one - ast_code_to_id[normalized_code]["shorter_source_code"] = candidate.source_code - ast_code_to_id[normalized_code]["diff_len"] = new_diff_len continue - ast_code_to_id[normalized_code] = { + self.ast_code_to_id[normalized_code] = { "optimization_id": candidate.optimization_id, "shorter_source_code": candidate.source_code, "diff_len": diff_length(candidate.source_code.flat, code_context.read_writable_code.flat), } - run_results = self.run_optimized_candidate( + self.reset_optimization_metrics_for_candidate( + candidate.optimization_id, speedup_ratios, is_correct, optimized_runtimes + ) + run_results, new_candidate = self.run_optimized_candidate( optimization_candidate_index=candidate_index, baseline_results=original_code_baseline, original_helper_code=original_helper_code, file_path_to_helper_classes=file_path_to_helper_classes, + code_context=code_context, + candidate=candidate, + exp_type=exp_type, ) + if candidate.optimization_id != new_candidate.optimization_id: + # override the candidate if the optimization_id has changed, this may happen if the candidate was modified by the code-repair + candidate = new_candidate + console.rule() if not is_successful(run_results): - optimized_runtimes[candidate.optimization_id] = None - is_correct[candidate.optimization_id] = False - speedup_ratios[candidate.optimization_id] = None + self.reset_optimization_metrics_for_candidate( + candidate.optimization_id, speedup_ratios, is_correct, optimized_runtimes + ) else: candidate_result: OptimizedCandidateResult = run_results.unwrap() best_test_runtime = candidate_result.best_test_runtime @@ -742,7 +783,7 @@ def determine_best_candidate( for valid_opt in valid_optimizations: valid_opt_normalized_code = normalize_code(valid_opt.candidate.source_code.flat.strip()) new_candidate_with_shorter_code = OptimizedCandidate( - source_code=ast_code_to_id[valid_opt_normalized_code]["shorter_source_code"], + source_code=self.ast_code_to_id[valid_opt_normalized_code]["shorter_source_code"], optimization_id=valid_opt.candidate.optimization_id, explanation=valid_opt.candidate.explanation, ) @@ -839,6 +880,24 @@ def refine_optimizations( ] return executor.submit(ai_service_client.optimize_python_code_refinement, request=request) + def code_repair_optimizations( + self, + original_source_code: str, + modified_source_code: str, + test_diffs: list[TestDiff], + trace_id: str, + optimization_id: str, + ai_service_client: AiServiceClient, + ) -> OptimizedCandidate | None: + request = AIServiceCodeRepairRequest( + optimization_id=optimization_id, + original_source_code=original_source_code, + modified_source_code=modified_source_code, + test_diffs=test_diffs, + trace_id=trace_id, + ) + return ai_service_client.optimize_python_code_repair(request=request) + def log_successful_optimization( self, explanation: Explanation, generated_tests: GeneratedTestsList, exp_type: str ) -> None: @@ -1752,14 +1811,22 @@ def establish_original_code_baseline( ) ) - def run_optimized_candidate( + def get_results_not_matched_error(self) -> Failure: + logger.info("h4|Test results did not match the test results of the original code ❌") + console.rule() + return Failure("Test results did not match the test results of the original code.") + + def run_optimized_candidate( # noqa: PLR0911 self, *, optimization_candidate_index: int, baseline_results: OriginalCodeBaseline, original_helper_code: dict[Path, str], file_path_to_helper_classes: dict[Path, set[str]], - ) -> Result[OptimizedCandidateResult, str]: + code_context: CodeOptimizationContext, + candidate: OptimizedCandidate, + exp_type: str, + ) -> tuple[Result[OptimizedCandidateResult, str], OptimizedCandidate]: assert (test_framework := self.args.test_framework) in {"pytest", "unittest"} # noqa: RUF018 with progress_bar("Testing optimization candidate"): @@ -1808,13 +1875,78 @@ def run_optimized_candidate( ) ) console.rule() - if compare_test_results(baseline_results.behavior_test_results, candidate_behavior_results): + # print(type(code_context), type(candidate)) + match, diffs = compare_test_results(baseline_results.behavior_test_results, candidate_behavior_results) + if match: logger.info("h3|Test results matched ✅") console.rule() else: - logger.info("h4|Test results did not match the test results of the original code ❌") - console.rule() - return Failure("Test results did not match the test results of the original code.") + result_unmatched_perc = len(diffs) / len(candidate_behavior_results) + if result_unmatched_perc > 0.5: + # if the test unmatched percentage is greater than 50%, we can't fix it + return self.get_results_not_matched_error(), candidate + + if candidate.optimization_id.endswith("cdrp"): + # prevent looping for now + return self.get_results_not_matched_error(), candidate + + ai_service_client = self.aiservice_client if exp_type == "EXP0" else self.local_aiservice_client + + with progress_bar("Some of the test results are not matching, let me see if I can fix this"): + new_candidate = self.code_repair_optimizations( + original_source_code=code_context.read_writable_code.markdown, + modified_source_code=candidate.source_code.markdown, + test_diffs=diffs, + trace_id=self.function_trace_id[:-4] + exp_type + if self.experiment_id + else self.function_trace_id, + ai_service_client=ai_service_client, + optimization_id=candidate.optimization_id, + ) + if not new_candidate: + return Failure("Code repair failed to generate a valid candidate."), candidate + + code_print( + new_candidate.source_code.flat, + file_name=f"candidate_{optimization_candidate_index}.py", + function_name=self.function_to_optimize.function_name, + ) + normalized_code = normalize_code(new_candidate.source_code.flat.strip()) + self.ast_code_to_id[normalized_code] = { + "optimization_id": new_candidate.optimization_id, + "shorter_source_code": new_candidate.source_code, + "diff_len": diff_length(new_candidate.source_code.flat, code_context.read_writable_code.flat), + } + + try: + # revert first to original code then replace with new repaired code, so we don't get any weird behavior + self.write_code_and_helpers( + self.function_to_optimize_source_code, original_helper_code, self.function_to_optimize.file_path + ) + did_update = self.replace_function_and_helpers_with_optimized_code( + code_context=code_context, + optimized_code=new_candidate.source_code, + original_helper_code=original_helper_code, + ) + if did_update: + return self.run_optimized_candidate( + optimization_candidate_index=optimization_candidate_index, + baseline_results=baseline_results, + original_helper_code=original_helper_code, + file_path_to_helper_classes=file_path_to_helper_classes, + code_context=code_context, + candidate=new_candidate, + exp_type=exp_type, + ) + msg = "No functions were replaced in the optimized code. Skipping optimization candidate." + logger.warning(f"force_lsp|{msg}") + return Failure(msg), candidate + except (ValueError, SyntaxError, cst.ParserSyntaxError, AttributeError) as e: + logger.error(e) + self.write_code_and_helpers( + self.function_to_optimize_source_code, original_helper_code, self.function_to_optimize.file_path + ) + return Failure("Code repair failed to generate a valid candidate."), candidate logger.info(f"loading|Running performance tests for candidate {optimization_candidate_index}...") @@ -1906,7 +2038,7 @@ def run_optimized_candidate( total_candidate_timing=total_candidate_timing, async_throughput=candidate_async_throughput, ) - ) + ), candidate def run_and_parse_tests( self, diff --git a/codeflash/verification/equivalence.py b/codeflash/verification/equivalence.py index 9d7f5ba2c..545616bae 100644 --- a/codeflash/verification/equivalence.py +++ b/codeflash/verification/equivalence.py @@ -1,27 +1,46 @@ +from __future__ import annotations + import sys +from typing import TYPE_CHECKING from codeflash.cli_cmds.console import logger -from codeflash.models.models import TestResults, TestType, VerificationType +from codeflash.models.models import TestDiff, TestDiffScope, TestResults, TestType, VerificationType from codeflash.verification.comparator import comparator +if TYPE_CHECKING: + from codeflash.models.models import TestResults + INCREASED_RECURSION_LIMIT = 5000 -def compare_test_results(original_results: TestResults, candidate_results: TestResults) -> bool: +def compare_test_results(original_results: TestResults, candidate_results: TestResults) -> tuple[bool, list[TestDiff]]: # This is meant to be only called with test results for the first loop index if len(original_results) == 0 or len(candidate_results) == 0: - return False # empty test results are not equal + return False, [] # empty test results are not equal original_recursion_limit = sys.getrecursionlimit() if original_recursion_limit < INCREASED_RECURSION_LIMIT: sys.setrecursionlimit(INCREASED_RECURSION_LIMIT) # Increase recursion limit to avoid RecursionError test_ids_superset = original_results.get_all_unique_invocation_loop_ids().union( set(candidate_results.get_all_unique_invocation_loop_ids()) ) - are_equal: bool = True + test_diffs: list[TestDiff] = [] did_all_timeout: bool = True for test_id in test_ids_superset: original_test_result = original_results.get_by_unique_invocation_loop_id(test_id) cdd_test_result = candidate_results.get_by_unique_invocation_loop_id(test_id) + candidate_test_failures = candidate_results.test_failures + original_test_failures = original_results.test_failures + cdd_pytest_error = ( + candidate_test_failures.get(original_test_result.id.test_fn_qualified_name(), "") + if candidate_test_failures + else "" + ) + original_pytest_error = ( + original_test_failures.get(original_test_result.id.test_fn_qualified_name(), "") + if original_test_failures + else "" + ) + if cdd_test_result is not None and original_test_result is None: continue # If helper function instance_state verification is not present, that's ok. continue @@ -32,8 +51,7 @@ def compare_test_results(original_results: TestResults, candidate_results: TestR ): continue if original_test_result is None or cdd_test_result is None: - are_equal = False - break + return False, [] did_all_timeout = did_all_timeout and original_test_result.timed_out if original_test_result.timed_out: continue @@ -43,42 +61,53 @@ def compare_test_results(original_results: TestResults, candidate_results: TestR in {VerificationType.INIT_STATE_HELPER, VerificationType.INIT_STATE_FTO} ): superset_obj = True + + test_src_code = original_test_result.id.get_src_code(original_test_result.file_name) + test_diff = TestDiff( + scope=TestDiffScope.RETURN_VALUE, + original_value=repr(original_test_result.return_value), + candidate_value=repr(cdd_test_result.return_value), + test_src_code=test_src_code, + candidate_pytest_error=cdd_pytest_error, + original_pass=original_test_result.did_pass, + candidate_pass=cdd_test_result.did_pass, + original_pytest_error=original_pytest_error, + ) if not comparator(original_test_result.return_value, cdd_test_result.return_value, superset_obj=superset_obj): - are_equal = False + test_diff.scope = TestDiffScope.RETURN_VALUE + test_diffs.append(test_diff) + try: logger.debug( - "File Name: %s\n" - "Test Type: %s\n" - "Verification Type: %s\n" - "Invocation ID: %s\n" - "Original return value: %s\n" - "Candidate return value: %s\n" - "-------------------", - original_test_result.file_name, - original_test_result.test_type, - original_test_result.verification_type, - original_test_result.id, - original_test_result.return_value, - cdd_test_result.return_value, + f"File Name: {original_test_result.file_name}\n" + f"Test Type: {original_test_result.test_type}\n" + f"Verification Type: {original_test_result.verification_type}\n" + f"Invocation ID: {original_test_result.id}\n" + f"Original return value: {original_test_result.return_value}\n" + f"Candidate return value: {cdd_test_result.return_value}\n" ) except Exception as e: logger.error(e) - break - if (original_test_result.stdout and cdd_test_result.stdout) and not comparator( + elif (original_test_result.stdout and cdd_test_result.stdout) and not comparator( original_test_result.stdout, cdd_test_result.stdout ): - are_equal = False - break + test_diff.scope = TestDiffScope.STDOUT + test_diff.original_value = str(original_test_result.stdout) + test_diff.candidate_value = str(cdd_test_result.stdout) + test_diffs.append(test_diff) - if original_test_result.test_type in { + elif original_test_result.test_type in { TestType.EXISTING_UNIT_TEST, TestType.CONCOLIC_COVERAGE_TEST, TestType.GENERATED_REGRESSION, TestType.REPLAY_TEST, } and (cdd_test_result.did_pass != original_test_result.did_pass): - are_equal = False - break + test_diff.scope = TestDiffScope.DID_PASS + test_diff.original_value = str(original_test_result.did_pass) + test_diff.candidate_value = str(cdd_test_result.did_pass) + test_diffs.append(test_diff) + sys.setrecursionlimit(original_recursion_limit) if did_all_timeout: - return False - return are_equal + return False, test_diffs + return len(test_diffs) == 0, test_diffs diff --git a/codeflash/verification/parse_test_output.py b/codeflash/verification/parse_test_output.py index ef513a0a3..f5cdad9d1 100644 --- a/codeflash/verification/parse_test_output.py +++ b/codeflash/verification/parse_test_output.py @@ -512,6 +512,61 @@ def merge_test_results( return merged_test_results +FAILURES_HEADER_RE = re.compile(r"=+ FAILURES =+") +TEST_HEADER_RE = re.compile(r"_{3,}\s*(.*?)\s*_{3,}$") + + +def parse_test_failures_from_stdout(test_results: TestResults, stdout: str) -> TestResults: + """Extract individual pytest test failures from stdout grouped by test case qualified name, and add them to the test results.""" + lines = stdout.splitlines() + start = end = None + + for i, line in enumerate(lines): + if FAILURES_HEADER_RE.search(line.strip()): + start = i + break + + if start is None: + return test_results + + for j in range(start + 1, len(lines)): + stripped = lines[j].strip() + if "short test summary info" in stripped: + end = j + break + # any new === section === block + if stripped.startswith("=") and stripped.count("=") > 3: + end = j + break + + # If no clear "end", just grap the rest of the string + if end is None: + end = len(lines) + + failure_block = lines[start:end] + + failures: dict[str, str] = {} + current_name = None + current_lines: list[str] = [] + + for line in failure_block: + m = TEST_HEADER_RE.match(line.strip()) + if m: + if current_name is not None: + failures[current_name] = "".join(current_lines) + + current_name = m.group(1) + current_lines = [] + elif current_name: + current_lines.append(line + "\n") + + if current_name: + failures[current_name] = "".join(current_lines) + + test_results.test_failures = failures + return test_results + + def parse_test_results( test_xml_path: Path, test_files: TestFiles, @@ -572,4 +627,9 @@ def parse_test_results( function_name=function_name, ) coverage.log_coverage() + try: + parse_test_failures_from_stdout(results, run_result.stdout) + except Exception as e: + logger.exception(e) + return results, coverage if all_args else None diff --git a/tests/test_codeflash_capture.py b/tests/test_codeflash_capture.py index c326cecc4..79133bc15 100644 --- a/tests/test_codeflash_capture.py +++ b/tests/test_codeflash_capture.py @@ -502,7 +502,8 @@ def __init__(self, x=2): pytest_max_loops=1, testing_time=0.1, ) - assert compare_test_results(test_results, test_results2) + match, _ = compare_test_results(test_results, test_results2) + assert match finally: test_path.unlink(missing_ok=True) @@ -626,7 +627,8 @@ def __init__(self, *args, **kwargs): testing_time=0.1, ) - assert compare_test_results(test_results, results2) + match, _ = compare_test_results(test_results, results2) + assert match finally: test_path.unlink(missing_ok=True) @@ -754,7 +756,8 @@ def __init__(self, x=2): testing_time=0.1, ) - assert compare_test_results(test_results, test_results2) + match, _ = compare_test_results(test_results, test_results2) + assert match finally: test_path.unlink(missing_ok=True) sample_code_path.unlink(missing_ok=True) @@ -902,7 +905,8 @@ def another_helper(self): testing_time=0.1, ) - assert compare_test_results(test_results, results2) + match, _ = compare_test_results(test_results, results2) + assert match finally: test_path.unlink(missing_ok=True) @@ -1132,7 +1136,8 @@ def target_function(self): ) # Remove instrumentation FunctionOptimizer.write_code_and_helpers(candidate_fto_code, candidate_helper_code, fto.file_path) - assert not compare_test_results(test_results, mutated_test_results) + match, _ = compare_test_results(test_results, mutated_test_results) + assert not match # This fto code stopped using a helper class. it should still pass no_helper1_fto_code = """ @@ -1170,10 +1175,304 @@ def target_function(self): ) # Remove instrumentation FunctionOptimizer.write_code_and_helpers(candidate_fto_code, candidate_helper_code, fto.file_path) - assert compare_test_results(test_results, no_helper1_test_results) + match, _ = compare_test_results(test_results, no_helper1_test_results) + assert match finally: test_path.unlink(missing_ok=True) fto_file_path.unlink(missing_ok=True) helper_path_1.unlink(missing_ok=True) helper_path_2.unlink(missing_ok=True) + +def test_instrument_codeflash_capture_and_run_tests_2() -> None: + # End to end run that instruments code and runs tests. Made to be similar to code used in the optimizer.py + test_code = """import math +import pytest +from typing import List, Tuple, Optional +from code_to_optimize.tests.pytest.fto_file import calculate_portfolio_metrics + +def test_calculate_portfolio_metrics(): + # Test case 1: Basic portfolio + investments = [ + ('Stocks', 0.6, 0.12), + ('Bonds', 0.3, 0.04), + ('Cash', 0.1, 0.01) + ] + + result = calculate_portfolio_metrics(investments) + + # Check weighted return calculation + expected_return = 0.6*0.12 + 0.3*0.04 + 0.1*0.01 + assert abs(result['weighted_return'] - expected_return) < 1e-10 + + # Check volatility calculation + expected_vol = math.sqrt((0.6*0.12)**2 + (0.3*0.04)**2 + (0.1*0.01)**2) + assert abs(result['volatility'] - expected_vol) < 1e-10 + + # Check Sharpe ratio + expected_sharpe = (expected_return - 0.02) / expected_vol + assert abs(result['sharpe_ratio'] - expected_sharpe) < 1e-10 + + # Check best/worst performers + assert result['best_performing'][0] == 'Stocks' + assert result['worst_performing'][0] == 'Cash' + assert result['total_assets'] == 3 + +def test_empty_investments(): + with pytest.raises(ValueError, match="Investments list cannot be empty"): + calculate_portfolio_metrics([]) + +def test_weights_not_sum_to_one(): + investments = [('Stock', 0.5, 0.1), ('Bond', 0.4, 0.05)] + with pytest.raises(ValueError, match="Portfolio weights must sum to 1.0"): + calculate_portfolio_metrics(investments) + +def test_zero_volatility(): + investments = [('Cash', 1.0, 0.0)] + result = calculate_portfolio_metrics(investments, risk_free_rate=0.0) + assert result['sharpe_ratio'] == 0.0 + assert result['volatility'] == 0.0 +""" + + original_code = """import math +from typing import List, Tuple, Optional + +def calculate_portfolio_metrics( + investments: List[Tuple[str, float, float]], + risk_free_rate: float = 0.02 +) -> dict: + if not investments: + raise ValueError("Investments list cannot be empty") + + if abs(sum(weight for _, weight, _ in investments) - 1.0) > 1e-10: + raise ValueError("Portfolio weights must sum to 1.0") + + # Calculate weighted return + weighted_return = sum(weight * ret for _, weight, ret in investments) + + # Calculate portfolio volatility (simplified) + volatility = math.sqrt(sum((weight * ret) ** 2 for _, weight, ret in investments)) + + # Calculate Sharpe ratio + if volatility == 0: + sharpe_ratio = 0.0 + else: + sharpe_ratio = (weighted_return - risk_free_rate) / volatility + + # Find best and worst performing assets + best_asset = max(investments, key=lambda x: x[2]) + worst_asset = min(investments, key=lambda x: x[2]) + + return { + 'weighted_return': round(weighted_return, 6), + 'volatility': round(volatility, 6), + 'sharpe_ratio': round(sharpe_ratio, 6), + 'best_performing': (best_asset[0], round(best_asset[2], 6)), + 'worst_performing': (worst_asset[0], round(worst_asset[2], 6)), + 'total_assets': len(investments) + } +""" + test_dir = (Path(__file__).parent.parent / "code_to_optimize" / "tests" / "pytest").resolve() + test_file_name = "test_multiple_helpers.py" + + fto_file_name = "fto_file.py" + + test_path = test_dir / test_file_name + test_path_perf = test_dir / "test_multiple_helpers_perf.py" + fto_file_path = test_dir / fto_file_name + + tests_root = Path(__file__).parent.resolve() / "../code_to_optimize/tests/pytest/" + project_root_path = (Path(__file__).parent / "..").resolve() + + try: + with fto_file_path.open("w") as f: + f.write(original_code) + with test_path.open("w") as f: + f.write(test_code) + + fto = FunctionToOptimize("calculate_portfolio_metrics", fto_file_path, parents=[]) + file_path_to_helper_class = { + } + instrument_codeflash_capture(fto, file_path_to_helper_class, tests_root) + test_env = os.environ.copy() + test_env["CODEFLASH_TEST_ITERATION"] = "0" + test_env["CODEFLASH_LOOP_INDEX"] = "1" + + test_type = TestType.EXISTING_UNIT_TEST + test_config = TestConfig( + tests_root=tests_root, + tests_project_rootdir=project_root_path, + project_root_path=project_root_path, + test_framework="pytest", + pytest_cmd="pytest", + ) + func_optimizer = FunctionOptimizer(function_to_optimize=fto, test_cfg=test_config) + func_optimizer.test_files = TestFiles( + test_files=[ + TestFile( + instrumented_behavior_file_path=test_path, + test_type=test_type, + original_file_path=test_path, + benchmarking_file_path=test_path_perf, + ) + ] + ) + # Code in optimizer.py + # Instrument codeflash capture + candidate_fto_code = Path(fto.file_path).read_text("utf-8") + candidate_helper_code = {} + for file_path in file_path_to_helper_class: + candidate_helper_code[file_path] = Path(file_path).read_text("utf-8") + file_path_to_helper_classes = { + } + instrument_codeflash_capture(fto, file_path_to_helper_classes, tests_root) + + test_results, coverage_data = func_optimizer.run_and_parse_tests( + testing_type=TestingMode.BEHAVIOR, + test_env=test_env, + test_files=func_optimizer.test_files, + optimization_iteration=0, + pytest_min_loops=1, + pytest_max_loops=1, + testing_time=0.1, + ) + + # Remove instrumentation + FunctionOptimizer.write_code_and_helpers(candidate_fto_code, candidate_helper_code, fto.file_path) + + # Now, let's say we optimize the code and make changes. + new_fto_code = """import math +from typing import List, Tuple, Optional + +def calculate_portfolio_metrics( + investments: List[Tuple[str, float, float]], + risk_free_rate: float = 0.02 +) -> dict: + if not investments: + raise ValueError("Investments list cannot be empty") + + total_weight = sum(w for _, w, _ in investments) + if total_weight != 1.0: # Should use tolerance check + raise ValueError("Portfolio weights must sum to 1.0") + + weighted_return = 1.0 + for _, weight, ret in investments: + weighted_return *= (1 + ret) ** weight + weighted_return = weighted_return - 1.0 # Convert back from geometric + + returns = [r for _, _, r in investments] + mean_return = sum(returns) / len(returns) + volatility = math.sqrt(sum((r - mean_return) ** 2 for r in returns) / len(returns)) + + # BUG 4: Sharpe ratio calculation is correct but uses wrong inputs + if volatility == 0: + sharpe_ratio = 0.0 + else: + sharpe_ratio = (weighted_return - risk_free_rate) / volatility + + def risk_adjusted_return(return_val, weight): + return (return_val - risk_free_rate) / (weight * return_val) if weight * return_val != 0 else return_val + + best_asset = max(investments, key=lambda x: risk_adjusted_return(x[2], x[1])) + worst_asset = min(investments, key=lambda x: risk_adjusted_return(x[2], x[1])) + + return { + "weighted_return": round(weighted_return, 6), + "volatility": 2, + "sharpe_ratio": round(sharpe_ratio, 6), + "best_performing": (best_asset[0], round(best_asset[2], 6)), + "worst_performing": (worst_asset[0], round(worst_asset[2], 6)), + "total_assets": len(investments), + } +""" + with fto_file_path.open("w") as f: + f.write(new_fto_code) + # Instrument codeflash capture + candidate_fto_code = Path(fto.file_path).read_text("utf-8") + candidate_helper_code = {} + for file_path in file_path_to_helper_class: + candidate_helper_code[file_path] = Path(file_path).read_text("utf-8") + file_path_to_helper_classes = {} + instrument_codeflash_capture(fto, file_path_to_helper_classes, tests_root) + modified_test_results, coverage_data = func_optimizer.run_and_parse_tests( + testing_type=TestingMode.BEHAVIOR, + test_env=test_env, + test_files=func_optimizer.test_files, + optimization_iteration=0, + pytest_min_loops=1, + pytest_max_loops=1, + testing_time=0.1, + ) + # Remove instrumentation + FunctionOptimizer.write_code_and_helpers(candidate_fto_code, candidate_helper_code, fto.file_path) + matched, diffs = compare_test_results(test_results, modified_test_results) + + assert not matched + + new_fixed_code = """import math +from typing import List, Tuple, Optional + +def calculate_portfolio_metrics( + investments: List[Tuple[str, float, float]], + risk_free_rate: float = 0.02 +) -> dict: + if not investments: + raise ValueError("Investments list cannot be empty") + + # Tolerant weight check (matches original) + total_weight = sum(weight for _, weight, _ in investments) + if abs(total_weight - 1.0) > 1e-10: + raise ValueError("Portfolio weights must sum to 1.0") + + # Same weighted return as original + weighted_return = sum(weight * ret for _, weight, ret in investments) + + # Same volatility formula as original + volatility = math.sqrt(sum((weight * ret) ** 2 for _, weight, ret in investments)) + + # Same Sharpe ratio logic + if volatility == 0: + sharpe_ratio = 0.0 + else: + sharpe_ratio = (weighted_return - risk_free_rate) / volatility + + # Same best/worst logic (based on return only) + best_asset = max(investments, key=lambda x: x[2]) + worst_asset = min(investments, key=lambda x: x[2]) + + return { + "weighted_return": round(weighted_return, 6), + "volatility": round(volatility, 6), + "sharpe_ratio": round(sharpe_ratio, 6), + "best_performing": (best_asset[0], round(best_asset[2], 6)), + "worst_performing": (worst_asset[0], round(worst_asset[2], 6)), + "total_assets": len(investments), + } +""" + with fto_file_path.open("w") as f: + f.write(new_fixed_code) + candidate_fto_code = Path(fto.file_path).read_text("utf-8") + candidate_helper_code = {} + for file_path in file_path_to_helper_class: + candidate_helper_code[file_path] = Path(file_path).read_text("utf-8") + file_path_to_helper_classes = {} + instrument_codeflash_capture(fto, file_path_to_helper_classes, tests_root) + modified_test_results_2, coverage_data = func_optimizer.run_and_parse_tests( + testing_type=TestingMode.BEHAVIOR, + test_env=test_env, + test_files=func_optimizer.test_files, + optimization_iteration=0, + pytest_min_loops=1, + pytest_max_loops=1, + testing_time=0.1, + ) + # Remove instrumentation + FunctionOptimizer.write_code_and_helpers(candidate_fto_code, candidate_helper_code, fto.file_path) + matched, diffs = compare_test_results(test_results, modified_test_results_2) + # now the test should match and no diffs should be found + assert len(diffs) == 0 + assert matched + + finally: + test_path.unlink(missing_ok=True) + fto_file_path.unlink(missing_ok=True) \ No newline at end of file diff --git a/tests/test_comparator.py b/tests/test_comparator.py index 06d178f95..6c2781229 100644 --- a/tests/test_comparator.py +++ b/tests/test_comparator.py @@ -1176,7 +1176,8 @@ def test_compare_results_fn(): ) ) - assert compare_test_results(original_results, new_results_1) + match, _ = compare_test_results(original_results, new_results_1) + assert match new_results_2 = TestResults() new_results_2.add( @@ -1199,7 +1200,8 @@ def test_compare_results_fn(): ) ) - assert not compare_test_results(original_results, new_results_2) + match, _ = compare_test_results(original_results, new_results_2) + assert not match new_results_3 = TestResults() new_results_3.add( @@ -1241,7 +1243,8 @@ def test_compare_results_fn(): ) ) - assert compare_test_results(original_results, new_results_3) + match, _ = compare_test_results(original_results, new_results_3) + assert match new_results_4 = TestResults() new_results_4.add( @@ -1264,7 +1267,8 @@ def test_compare_results_fn(): ) ) - assert not compare_test_results(original_results, new_results_4) + match, _ = compare_test_results(original_results, new_results_4) + assert not match new_results_5_baseline = TestResults() new_results_5_baseline.add( @@ -1308,7 +1312,8 @@ def test_compare_results_fn(): ) ) - assert not compare_test_results(new_results_5_baseline, new_results_5_opt) + match, _ = compare_test_results(new_results_5_baseline, new_results_5_opt) + assert not match new_results_6_baseline = TestResults() new_results_6_baseline.add( @@ -1352,9 +1357,11 @@ def test_compare_results_fn(): ) ) - assert not compare_test_results(new_results_6_baseline, new_results_6_opt) + match, _ = compare_test_results(new_results_6_baseline, new_results_6_opt) + assert not match - assert not compare_test_results(TestResults(), TestResults()) + match, _ = compare_test_results(TestResults(), TestResults()) + assert not match def test_exceptions(): diff --git a/tests/test_instrument_all_and_run.py b/tests/test_instrument_all_and_run.py index ece7d38b0..7bdfa364b 100644 --- a/tests/test_instrument_all_and_run.py +++ b/tests/test_instrument_all_and_run.py @@ -223,7 +223,8 @@ def test_sort(): result: [0, 1, 2, 3, 4, 5] """ assert out_str == results2[0].stdout - assert compare_test_results(test_results, results2) + match, _ = compare_test_results(test_results, results2) + assert match finally: fto_path.write_text(original_code, "utf-8") test_path.unlink(missing_ok=True) @@ -368,7 +369,8 @@ def test_sort(): assert test_results[1].return_value == ([0, 1, 2, 3, 4, 5],) out_str = """codeflash stdout : BubbleSorter.sorter() called\n""" assert test_results[1].stdout == out_str - assert compare_test_results(test_results, test_results) + match, _ = compare_test_results(test_results, test_results) + assert match assert test_results[2].id.function_getting_tested == "BubbleSorter.__init__" assert test_results[2].id.test_function_name == "test_sort" assert test_results[2].did_pass @@ -396,7 +398,8 @@ def test_sort(): testing_time=0.1, ) - assert compare_test_results(test_results, results2) + match, _ = compare_test_results(test_results, results2) + assert match # Replace with optimized code that mutated instance attribute optimized_code = """ @@ -491,7 +494,8 @@ def sorter(self, arr): ) assert new_test_results[3].runtime > 0 assert new_test_results[3].did_pass - assert not compare_test_results(test_results, new_test_results) + match, _ = compare_test_results(test_results, new_test_results) + assert not match finally: fto_path.write_text(original_code, "utf-8") @@ -630,7 +634,8 @@ def test_sort(): out_str = """codeflash stdout : BubbleSorter.sorter_classmethod() called """ assert test_results[0].stdout == out_str - assert compare_test_results(test_results, test_results) + match, _ = compare_test_results(test_results, test_results) + assert match assert test_results[1].id.function_getting_tested == "BubbleSorter.sorter_classmethod" assert test_results[1].id.iteration_id == "4_0" @@ -655,7 +660,8 @@ def test_sort(): testing_time=0.1, ) - assert compare_test_results(test_results, results2) + match, _ = compare_test_results(test_results, results2) + assert match finally: fto_path.write_text(original_code, "utf-8") @@ -794,7 +800,8 @@ def test_sort(): out_str = """codeflash stdout : BubbleSorter.sorter_staticmethod() called """ assert test_results[0].stdout == out_str - assert compare_test_results(test_results, test_results) + match, _ = compare_test_results(test_results, test_results) + assert match assert test_results[1].id.function_getting_tested == "BubbleSorter.sorter_staticmethod" assert test_results[1].id.iteration_id == "4_0" @@ -819,7 +826,8 @@ def test_sort(): testing_time=0.1, ) - assert compare_test_results(test_results, results2) + match, _ = compare_test_results(test_results, results2) + assert match finally: fto_path.write_text(original_code, "utf-8") diff --git a/tests/test_instrumentation_run_results_aiservice.py b/tests/test_instrumentation_run_results_aiservice.py index cae2c76f1..03556718d 100644 --- a/tests/test_instrumentation_run_results_aiservice.py +++ b/tests/test_instrumentation_run_results_aiservice.py @@ -221,10 +221,10 @@ def sorter(self, arr): testing_time=0.1, ) # assert test_results_mutated_attr[0].return_value[1]["self"].x == 1 TODO: add self as input to function - assert compare_test_results( + match, _ = compare_test_results( test_results, test_results_mutated_attr ) # Without codeflash capture, the init state was not verified, and the results are verified as correct even with the attribute mutated - + assert match assert test_results_mutated_attr[0].stdout == "codeflash stdout : BubbleSorter.sorter() called\n" finally: fto_path.write_text(original_code, "utf-8") @@ -403,9 +403,10 @@ def sorter(self, arr): assert test_results_mutated_attr[0].return_value[0] == {"x": 1} assert test_results_mutated_attr[0].verification_type == VerificationType.INIT_STATE_FTO assert test_results_mutated_attr[0].stdout == "" - assert not compare_test_results( + match,_ = compare_test_results( test_results, test_results_mutated_attr ) # The test should fail because the instance attribute was mutated + assert not match # Replace with optimized code that did not mutate existing instance attribute, but added a new one optimized_code_new_attr = """ import sys @@ -457,9 +458,10 @@ def sorter(self, arr): assert test_results_new_attr[0].stdout == "" # assert test_results_new_attr[1].return_value[1]["self"].x == 0 TODO: add self as input # assert test_results_new_attr[1].return_value[1]["self"].y == 2 TODO: add self as input - assert compare_test_results( + match,_ = compare_test_results( test_results, test_results_new_attr ) # The test should pass because the instance attribute was not mutated, only a new one was added + assert match finally: fto_path.write_text(original_code, "utf-8") test_path.unlink(missing_ok=True) diff --git a/tests/test_pickle_patcher.py b/tests/test_pickle_patcher.py index c67883c12..c05384d03 100644 --- a/tests/test_pickle_patcher.py +++ b/tests/test_pickle_patcher.py @@ -427,8 +427,8 @@ def bubble_sort_with_unused_socket(data_container): testing_time=1.0, ) assert len(optimized_test_results_unused_socket) == 1 - verification_result = compare_test_results(test_results_unused_socket, optimized_test_results_unused_socket) - assert verification_result is True + match, _ = compare_test_results(test_results_unused_socket, optimized_test_results_unused_socket) + assert match # Remove the previous instrumentation replay_test_path.write_text(original_replay_test_code) @@ -517,8 +517,8 @@ def bubble_sort_with_used_socket(data_container): assert test_results_used_socket.test_results[0].did_pass is False # Even though tests threw the same error, we reject this as the behavior of the unpickleable object could not be determined. - assert compare_test_results(test_results_used_socket, optimized_test_results_used_socket) is False - + match, _ = compare_test_results(test_results_used_socket, optimized_test_results_used_socket) + assert not match finally: # cleanup output_file.unlink(missing_ok=True)