Skip to content

Commit e73cb53

Browse files
committed
cc
1 parent 86b95ad commit e73cb53

File tree

1 file changed

+131
-50
lines changed

1 file changed

+131
-50
lines changed

codeflash/optimization/function_optimizer.py

Lines changed: 131 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,12 @@
33
import ast
44
import concurrent.futures
55
import os
6+
import queue
67
import random
78
import subprocess
89
import time
910
import uuid
10-
from collections import defaultdict, deque
11+
from collections import defaultdict
1112
from pathlib import Path
1213
from typing import TYPE_CHECKING
1314

@@ -104,6 +105,83 @@
104105
from codeflash.verification.verification_utils import TestConfig
105106

106107

108+
class CandidateProcessor:
109+
"""Handles candidate processing using a queue-based approach."""
110+
111+
def __init__(
112+
self,
113+
initial_candidates: list,
114+
future_line_profile_results: concurrent.futures.Future,
115+
future_all_refinements: list,
116+
) -> None:
117+
self.candidate_queue = queue.Queue()
118+
self.line_profiler_done = False
119+
self.refinement_done = False
120+
self.original_len = len(initial_candidates)
121+
122+
# Initialize queue with initial candidates
123+
for candidate in initial_candidates:
124+
self.candidate_queue.put(candidate)
125+
126+
self.future_line_profile_results = future_line_profile_results
127+
self.future_all_refinements = future_all_refinements
128+
129+
def get_next_candidate(self) -> OptimizedCandidate | None:
130+
"""Get the next candidate from the queue, handling async results as needed."""
131+
try:
132+
return self.candidate_queue.get_nowait()
133+
except queue.Empty:
134+
return self._handle_empty_queue()
135+
136+
def _handle_empty_queue(self) -> OptimizedCandidate | None:
137+
"""Handle empty queue by checking for pending async results."""
138+
if not self.line_profiler_done:
139+
return self._process_line_profiler_results()
140+
if self.line_profiler_done and not self.refinement_done:
141+
return self._process_refinement_results()
142+
return None # All done
143+
144+
def _process_line_profiler_results(self) -> OptimizedCandidate | None:
145+
"""Process line profiler results and add to queue."""
146+
logger.debug("all candidates processed, await candidates from line profiler")
147+
concurrent.futures.wait([self.future_line_profile_results])
148+
line_profile_results = self.future_line_profile_results.result()
149+
150+
for candidate in line_profile_results:
151+
self.candidate_queue.put(candidate)
152+
153+
self.original_len += len(line_profile_results)
154+
logger.info(f"Added results from line profiler to candidates, total candidates now: {self.original_len}")
155+
self.line_profiler_done = True
156+
157+
return self.get_next_candidate()
158+
159+
def _process_refinement_results(self) -> OptimizedCandidate | None:
160+
"""Process refinement results and add to queue."""
161+
concurrent.futures.wait(self.future_all_refinements)
162+
refinement_response = []
163+
164+
for future_refinement in self.future_all_refinements:
165+
possible_refinement = future_refinement.result()
166+
if len(possible_refinement) > 0:
167+
refinement_response.append(possible_refinement[0])
168+
169+
for candidate in refinement_response:
170+
self.candidate_queue.put(candidate)
171+
172+
self.original_len += len(refinement_response)
173+
logger.info(
174+
f"Added {len(refinement_response)} candidates from refinement, total candidates now: {self.original_len}"
175+
)
176+
self.refinement_done = True
177+
178+
return self.get_next_candidate()
179+
180+
def is_done(self) -> bool:
181+
"""Check if processing is complete."""
182+
return self.line_profiler_done and self.refinement_done and self.candidate_queue.empty()
183+
184+
107185
class FunctionOptimizer:
108186
def __init__(
109187
self,
@@ -378,15 +456,13 @@ def determine_best_candidate(
378456
f"{self.function_to_optimize.qualified_name}…"
379457
)
380458
console.rule()
381-
candidates = deque(candidates)
382-
refinement_done = False
383-
line_profiler_done = False
459+
384460
future_all_refinements: list[concurrent.futures.Future] = []
385461
ast_code_to_id = {}
386462
valid_optimizations = []
387463
optimizations_post = {} # we need to overwrite some opt candidates' code strings as they are no longer evaluated, instead their shorter/longer versions might be evaluated
388-
# Start a new thread for AI service request, start loop in main thread
389-
# check if aiservice request is complete, when it is complete, append result to the candidates list
464+
465+
# Start a new thread for AI service request
390466
ai_service_client = self.aiservice_client if exp_type == "EXP0" else self.local_aiservice_client
391467
future_line_profile_results = self.executor.submit(
392468
ai_service_client.optimize_python_code_line_profiler,
@@ -401,49 +477,23 @@ def determine_best_candidate(
401477
if self.experiment_id
402478
else None,
403479
)
480+
481+
# Initialize candidate processor
482+
processor = CandidateProcessor(candidates, future_line_profile_results, future_all_refinements)
404483
candidate_index = 0
405-
original_len = len(candidates)
406-
# TODO : We need to rewrite this candidate loop as a class, the container which has candidates receives new candidates at unknown times due to the async nature of lp and refinement calls,
407-
# TODO : in addition, the refinement calls depend on line profiler calls being complete so we need to check that reliably
408-
# write a class with queue.Queue architecture
409-
while True:
410-
try:
411-
if len(candidates) > 0:
412-
candidate = candidates.popleft()
413-
else:
414-
if not line_profiler_done:
415-
logger.debug("all candidates processed, await candidates from line profiler")
416-
concurrent.futures.wait([future_line_profile_results])
417-
line_profile_results = future_line_profile_results.result()
418-
candidates.extend(line_profile_results)
419-
original_len += len(line_profile_results)
420-
logger.info(
421-
f"Added results from line profiler to candidates, total candidates now: {original_len}"
422-
)
423-
line_profiler_done = True
424-
continue
425-
if line_profiler_done and not refinement_done:
426-
concurrent.futures.wait(future_all_refinements)
427-
refinement_response = []
428-
for future_refinement in future_all_refinements:
429-
possible_refinement = future_refinement.result()
430-
if len(possible_refinement) > 0: # if the api returns a valid response
431-
refinement_response.append(possible_refinement[0])
432-
candidates.extend(refinement_response)
433-
original_len += len(refinement_response)
434-
logger.info(
435-
f"Added {len(refinement_response)} candidates from refinement, total candidates now: {original_len}"
436-
)
437-
refinement_done = True
438-
continue
439-
if line_profiler_done and refinement_done:
440-
logger.debug("everything done, exiting")
441-
break
442484

485+
# Process candidates using queue-based approach
486+
while not processor.is_done():
487+
candidate = processor.get_next_candidate()
488+
if candidate is None:
489+
logger.debug("everything done, exiting")
490+
break
491+
492+
try:
443493
candidate_index += 1
444494
get_run_tmp_file(Path(f"test_return_values_{candidate_index}.bin")).unlink(missing_ok=True)
445495
get_run_tmp_file(Path(f"test_return_values_{candidate_index}.sqlite")).unlink(missing_ok=True)
446-
logger.info(f"Optimization candidate {candidate_index}/{original_len}:")
496+
logger.info(f"Optimization candidate {candidate_index}/{processor.original_len}:")
447497
code_print(candidate.source_code.flat)
448498
# map ast normalized code to diff len, unnormalized code
449499
# map opt id to the shortest unnormalized code
@@ -468,7 +518,7 @@ def determine_best_candidate(
468518
# check if this code has been evaluated before by checking the ast normalized code string
469519
normalized_code = ast.unparse(ast.parse(candidate.source_code.flat.strip()))
470520
if normalized_code in ast_code_to_id:
471-
logger.warning(
521+
logger.info(
472522
"Current candidate has been encountered before in testing, Skipping optimization candidate."
473523
)
474524
past_opt_id = ast_code_to_id[normalized_code]["optimization_id"]
@@ -746,7 +796,9 @@ def reformat_code_and_helpers(
746796
file_to_code_context = optimized_context.file_to_path()
747797
optimized_code = file_to_code_context.get(str(path.relative_to(self.project_root)), "")
748798

749-
new_code = format_code(self.args.formatter_cmds, path, optimized_code=optimized_code, check_diff=True)
799+
new_code = format_code(
800+
self.args.formatter_cmds, path, optimized_code=optimized_code, check_diff=True, exit_on_failure=False
801+
)
750802
if should_sort_imports:
751803
new_code = sort_imports(new_code)
752804

@@ -755,7 +807,11 @@ def reformat_code_and_helpers(
755807
module_abspath = hp.file_path
756808
hp_source_code = hp.source_code
757809
formatted_helper_code = format_code(
758-
self.args.formatter_cmds, module_abspath, optimized_code=hp_source_code, check_diff=True
810+
self.args.formatter_cmds,
811+
module_abspath,
812+
optimized_code=hp_source_code,
813+
check_diff=True,
814+
exit_on_failure=False,
759815
)
760816
if should_sort_imports:
761817
formatted_helper_code = sort_imports(formatted_helper_code)
@@ -1153,7 +1209,6 @@ def find_and_process_best_optimization(
11531209
original_helper_code,
11541210
code_context,
11551211
)
1156-
self.log_successful_optimization(explanation, generated_tests, exp_type)
11571212
return best_optimization
11581213

11591214
def process_review(
@@ -1227,7 +1282,10 @@ def process_review(
12271282
file_path=explanation.file_path,
12281283
benchmark_details=explanation.benchmark_details,
12291284
)
1230-
console.print(Panel(new_explanation_raw_str, title="Best Candidate Explanation", border_style="blue"))
1285+
self.log_successful_optimization(new_explanation, generated_tests, exp_type)
1286+
1287+
best_optimization.explanation_v2 = new_explanation.explanation_message()
1288+
12311289
data = {
12321290
"original_code": original_code_combined,
12331291
"new_code": new_code_combined,
@@ -1240,6 +1298,7 @@ def process_review(
12401298
"coverage_message": coverage_message,
12411299
"replay_tests": replay_tests,
12421300
"concolic_tests": concolic_tests,
1301+
"root_dir": self.project_root,
12431302
}
12441303

12451304
raise_pr = not self.args.no_pr
@@ -1248,13 +1307,35 @@ def process_review(
12481307
data["git_remote"] = self.args.git_remote
12491308
check_create_pr(**data)
12501309
elif self.args.staging_review:
1251-
create_staging(**data)
1310+
response = create_staging(**data)
1311+
if response.status_code == 200:
1312+
staging_url = f"https://app.codeflash.ai/review-optimizations/{self.function_trace_id[:-4] + exp_type if self.experiment_id else self.function_trace_id}"
1313+
console.print(
1314+
Panel(
1315+
f"[bold green]✅ Staging created:[/bold green]\n[link={staging_url}]{staging_url}[/link]",
1316+
title="Staging Link",
1317+
border_style="green",
1318+
)
1319+
)
1320+
else:
1321+
console.print(
1322+
Panel(
1323+
f"[bold red]❌ Failed to create staging[/bold red]\nStatus: {response.status_code}",
1324+
title="Staging Error",
1325+
border_style="red",
1326+
)
1327+
)
1328+
12521329
else:
12531330
# Mark optimization success since no PR will be created
12541331
mark_optimization_success(
12551332
trace_id=self.function_trace_id, is_optimization_found=best_optimization is not None
12561333
)
12571334

1335+
# If worktree mode, do not revert code and helpers,, otherwise we would have an empty diff when writing the patch in the lsp
1336+
if self.args.worktree:
1337+
return
1338+
12581339
if raise_pr and (
12591340
self.args.all
12601341
or env_utils.get_pr_number()

0 commit comments

Comments
 (0)