|
27 | 27 |
|
28 | 28 | from codeflash.discovery.functions_to_optimize import FunctionToOptimize |
29 | 29 | from codeflash.models.ExperimentMetadata import ExperimentMetadata |
30 | | - from codeflash.models.models import AIServiceRefinerRequest |
| 30 | + from codeflash.models.models import AIServiceCodeRepairRequest, AIServiceRefinerRequest |
31 | 31 | from codeflash.result.explanation import Explanation |
32 | 32 |
|
33 | 33 |
|
@@ -294,6 +294,59 @@ def optimize_python_code_refinement(self, request: list[AIServiceRefinerRequest] |
294 | 294 | console.rule() |
295 | 295 | return [] |
296 | 296 |
|
| 297 | + def optimize_python_code_repair(self, request: list[AIServiceCodeRepairRequest]) -> list[OptimizedCandidate]: |
| 298 | + """Optimize the given python code for performance by making a request to the Django endpoint. |
| 299 | +
|
| 300 | + Args: |
| 301 | + request: A list of optimization candidate details for refinement |
| 302 | +
|
| 303 | + Returns: |
| 304 | + ------- |
| 305 | + - List[OptimizationCandidate]: A list of Optimization Candidates. |
| 306 | +
|
| 307 | + """ |
| 308 | + payload = [ |
| 309 | + { |
| 310 | + "optimization_id": opt.optimization_id, |
| 311 | + "original_source_code": opt.original_source_code, |
| 312 | + "modified_source_code": opt.modified_source_code, |
| 313 | + "trace_id": opt.trace_id, |
| 314 | + } |
| 315 | + for opt in request |
| 316 | + ] |
| 317 | + # logger.debug(f"Repair {len(request)} optimizations…") |
| 318 | + console.rule() |
| 319 | + try: |
| 320 | + response = self.make_ai_service_request("/code_repair", payload=payload, timeout=120) |
| 321 | + except requests.exceptions.RequestException as e: |
| 322 | + logger.exception(f"Error generating optimization repair: {e}") |
| 323 | + ph("cli-optimize-error-caught", {"error": str(e)}) |
| 324 | + return [] |
| 325 | + |
| 326 | + if response.status_code == 200: |
| 327 | + refined_optimizations = response.json()["code_repairs"] |
| 328 | + logger.debug(f"Generated {len(refined_optimizations)} candidate refinements.") |
| 329 | + console.rule() |
| 330 | + |
| 331 | + refinements = self._get_valid_candidates(refined_optimizations) |
| 332 | + return [ |
| 333 | + OptimizedCandidate( |
| 334 | + source_code=c.source_code, |
| 335 | + explanation=c.explanation, |
| 336 | + optimization_id=c.optimization_id[:-4] + "cdrp", |
| 337 | + ) |
| 338 | + for c in refinements |
| 339 | + ] |
| 340 | + |
| 341 | + try: |
| 342 | + error = response.json()["error"] |
| 343 | + except Exception: |
| 344 | + error = response.text |
| 345 | + logger.error(f"Error generating optimized candidates: {response.status_code} - {error}") |
| 346 | + ph("cli-optimize-error-response", {"response_status_code": response.status_code, "error": error}) |
| 347 | + console.rule() |
| 348 | + return [] |
| 349 | + |
297 | 350 | def get_new_explanation( # noqa: D417 |
298 | 351 | self, |
299 | 352 | source_code: str, |
|
0 commit comments