Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
26 commits
Select commit Hold shift + click to select a range
5830a70
quick and dirty
mohammedahmed18 Nov 27, 2025
3e0440b
safter
mohammedahmed18 Nov 27, 2025
eb16cb2
Optimize parse_test_failures_from_stdout
codeflash-ai[bot] Nov 27, 2025
168118a
Merge pull request #946 from codeflash-ai/codeflash/optimize-pr945-20…
mohammedahmed18 Nov 27, 2025
a7f8816
fix tests
mohammedahmed18 Nov 27, 2025
4e9f894
linting
mohammedahmed18 Nov 27, 2025
1c9abaf
did it pass ?
mohammedahmed18 Nov 28, 2025
0b2d894
revert test optimization
mohammedahmed18 Nov 28, 2025
ecfa89f
cleaner
mohammedahmed18 Nov 28, 2025
6ea2545
test: try to fix the candidate and see if the diff is empty
mohammedahmed18 Nov 28, 2025
fe68772
capture all test discrepancies
Nov 30, 2025
ed39ec8
do the repair in main loop
Nov 30, 2025
142da4c
todo write backend endpoint
Dec 1, 2025
5a7c356
need to test now
Dec 1, 2025
8a28d0d
Merge branch 'feat/feedback-loop-for-unmatched-test-results' of githu…
mohammedahmed18 Dec 1, 2025
5ed5dfc
works, figure out logging
Dec 1, 2025
fe33c82
local db logging
Dec 1, 2025
83814be
ready to run experiments
Dec 1, 2025
0325444
logging fix
Dec 1, 2025
9f7ed90
handle test class methods for the test diff
mohammedahmed18 Dec 1, 2025
1ddc87c
Merge branch 'feat/feedback-loop-for-unmatched-test-results' of githu…
mohammedahmed18 Dec 1, 2025
6060ffb
codeflash suggestion
mohammedahmed18 Dec 2, 2025
1120d64
safer parsing
mohammedahmed18 Dec 2, 2025
c2e037a
better parsing for pytest stdout
mohammedahmed18 Dec 2, 2025
5703889
Merge branch 'feat/feedback-loop-for-unmatched-test-results' of githu…
mohammedahmed18 Dec 2, 2025
bd1ebf4
temp logging
Dec 3, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 55 additions & 1 deletion codeflash/api/aiservice.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@

from codeflash.discovery.functions_to_optimize import FunctionToOptimize
from codeflash.models.ExperimentMetadata import ExperimentMetadata
from codeflash.models.models import AIServiceRefinerRequest
from codeflash.models.models import AIServiceCodeRepairRequest, AIServiceRefinerRequest
from codeflash.result.explanation import Explanation


Expand Down Expand Up @@ -294,6 +294,60 @@ def optimize_python_code_refinement(self, request: list[AIServiceRefinerRequest]
console.rule()
return []

def optimize_python_code_repair(self, request: list[AIServiceCodeRepairRequest]) -> list[OptimizedCandidate]:
"""Optimize the given python code for performance by making a request to the Django endpoint.

Args:
request: A list of optimization candidate details for refinement

Returns:
-------
- List[OptimizationCandidate]: A list of Optimization Candidates.

"""
payload = [
{
"optimization_id": opt.optimization_id,
"original_source_code": opt.original_source_code,
"modified_source_code": opt.modified_source_code,
"test_details": opt.test_details,
"trace_id": opt.trace_id,
}
for opt in request
]
# logger.debug(f"Repair {len(request)} optimizations…")
console.rule()
try:
response = self.make_ai_service_request("/code_repair", payload=payload, timeout=120)
except requests.exceptions.RequestException as e:
logger.exception(f"Error generating optimization repair: {e}")
ph("cli-optimize-error-caught", {"error": str(e)})
return []

if response.status_code == 200:
refined_optimizations = response.json()["code_repairs"]
# logger.debug(f"Generated {len(refined_optimizations)} candidate refinements.")
console.rule()

refinements = self._get_valid_candidates(refined_optimizations)
return [
OptimizedCandidate(
source_code=c.source_code,
explanation=c.explanation,
optimization_id=c.optimization_id[:-4] + "cdrp",
)
for c in refinements
]

try:
error = response.json()["error"]
except Exception:
error = response.text
logger.error(f"Error generating optimized candidates: {response.status_code} - {error}")
ph("cli-optimize-error-response", {"response_status_code": response.status_code, "error": error})
console.rule()
return []

def get_new_explanation( # noqa: D417
self,
source_code: str,
Expand Down
49 changes: 49 additions & 0 deletions codeflash/models/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from collections import Counter, defaultdict
from typing import TYPE_CHECKING

import libcst as cst
from rich.tree import Tree

from codeflash.cli_cmds.console import DEBUG_MODE, lsp_log
Expand Down Expand Up @@ -47,6 +48,15 @@ class AIServiceRefinerRequest:
function_references: str | None = None


@dataclass(frozen=True)
class AIServiceCodeRepairRequest:
optimization_id: str
original_source_code: str
modified_source_code: str
test_details: str
trace_id: str


# If the method spam is in the class Ham, which is at the top level of the module eggs in the package foo, the fully
# qualified name of the method is foo.eggs.Ham.spam, its qualified name is Ham.spam, and its name is spam. The full name
# of the module is foo.eggs.
Expand Down Expand Up @@ -505,6 +515,42 @@ def id(self) -> str:
f"{self.function_getting_tested}:{self.iteration_id}"
)

# TestSuiteClass.test_function_name
def test_fn_qualified_name(self) -> str:
# Use f-string with inline conditional to reduce string concatenation operations
return (
f"{self.test_class_name}.{self.test_function_name}"
if self.test_class_name
else str(self.test_function_name)
)

def find_func_in_class(self, class_node: cst.ClassDef, func_name: str) -> Optional[cst.FunctionDef]:
for stmt in class_node.body.body:
if isinstance(stmt, cst.FunctionDef) and stmt.name.value == func_name:
return stmt
return None

def get_src_code(self, test_path: Path) -> Optional[str]:
if not test_path.exists():
return None
test_src = test_path.read_text(encoding="utf-8")
module_node = cst.parse_module(test_src)

if self.test_class_name:
for stmt in module_node.body:
if isinstance(stmt, cst.ClassDef) and stmt.name.value == self.test_class_name:
func_node = self.find_func_in_class(stmt, self.test_function_name)
if func_node:
return module_node.code_for_node(func_node).strip()
# class not found
return None

# Otherwise, look for a top level function
for stmt in module_node.body:
if isinstance(stmt, cst.FunctionDef) and stmt.name.value == self.test_function_name:
return module_node.code_for_node(stmt).strip()
return None

@staticmethod
def from_str_id(string_id: str, iteration_id: str | None = None) -> InvocationId:
components = string_id.split(":")
Expand Down Expand Up @@ -549,7 +595,10 @@ class TestResults(BaseModel): # noqa: PLW1641
# also we don't support deletion of test results elements - caution is advised
test_results: list[FunctionTestInvocation] = []
test_result_idx: dict[str, int] = {}

perf_stdout: Optional[str] = None
# mapping between test function name and stdout failure message
test_failures: Optional[dict[str, str]] = None

def add(self, function_test_invocation: FunctionTestInvocation) -> None:
unique_id = function_test_invocation.unique_invocation_loop_id
Expand Down
Loading
Loading