33import ast
44import concurrent .futures
55import os
6+ import queue
67import random
78import subprocess
89import time
910import uuid
10- from collections import defaultdict , deque
11+ from collections import defaultdict
1112from pathlib import Path
1213from typing import TYPE_CHECKING
1314
104105 from codeflash .verification .verification_utils import TestConfig
105106
106107
class CandidateProcessor:
    """Serve optimization candidates one at a time from a FIFO queue.

    The queue is seeded with the initial candidates. Additional candidates
    arrive from two asynchronous sources, which are drained lazily and in a
    fixed order, only when the queue runs empty:

    1. the line-profiler future (a single future yielding a list of candidates);
    2. the refinement futures (each yielding a list, of which only the first
       element is used).

    Once both sources have been consumed and the queue is empty,
    ``get_next_candidate`` returns ``None`` and ``is_done`` returns ``True``.

    ``original_len`` is the running total of candidates enqueued so far; it
    grows as asynchronous results are added and is read by the caller for
    progress messages.
    """

    def __init__(
        self,
        initial_candidates: list,
        future_line_profile_results: concurrent.futures.Future,
        future_all_refinements: list,
    ) -> None:
        """Seed the queue and remember the pending asynchronous sources.

        Args:
            initial_candidates: Candidates available immediately; enqueued in order.
            future_line_profile_results: Future resolving to a list of candidates.
            future_all_refinements: List of futures, each resolving to a list of
                candidates. The list is stored by reference (not copied), so
                futures appended to it after construction are still seen when
                refinements are eventually processed.

        """
        self.candidate_queue = queue.Queue()
        self.line_profiler_done = False
        self.refinement_done = False
        self.original_len = len(initial_candidates)

        for candidate in initial_candidates:
            self.candidate_queue.put(candidate)

        self.future_line_profile_results = future_line_profile_results
        self.future_all_refinements = future_all_refinements

    def get_next_candidate(self) -> OptimizedCandidate | None:
        """Return the next candidate, or ``None`` when every source is exhausted.

        May block while waiting for the line-profiler or refinement futures.
        """
        try:
            return self.candidate_queue.get_nowait()
        except queue.Empty:
            return self._handle_empty_queue()

    def _handle_empty_queue(self) -> OptimizedCandidate | None:
        """Pull in the next pending asynchronous source, if any remains."""
        if not self.line_profiler_done:
            return self._process_line_profiler_results()
        # Refinements are only consumed after the line-profiler results.
        if not self.refinement_done:
            return self._process_refinement_results()
        return None  # All done

    def _process_line_profiler_results(self) -> OptimizedCandidate | None:
        """Block on the line-profiler future, enqueue its candidates, and continue."""
        logger.debug("all candidates processed, await candidates from line profiler")
        # Future.result() blocks until completion, so no separate wait() is needed.
        # A missing (None) response is treated as "no candidates".
        line_profile_results = self.future_line_profile_results.result() or []

        for candidate in line_profile_results:
            self.candidate_queue.put(candidate)

        self.original_len += len(line_profile_results)
        logger.info(f"Added results from line profiler to candidates, total candidates now: {self.original_len}")
        # Mark done before recursing so an empty result falls through to refinements.
        self.line_profiler_done = True

        return self.get_next_candidate()

    def _process_refinement_results(self) -> OptimizedCandidate | None:
        """Block on all refinement futures, enqueue one candidate per response, and continue."""
        # Let every refinement finish before reading any result.
        concurrent.futures.wait(self.future_all_refinements)
        refinement_response = []

        for future_refinement in self.future_all_refinements:
            possible_refinement = future_refinement.result()
            # Skip empty or missing responses; only the first refinement is used.
            if possible_refinement:
                refinement_response.append(possible_refinement[0])

        for candidate in refinement_response:
            self.candidate_queue.put(candidate)

        self.original_len += len(refinement_response)
        logger.info(
            f"Added {len(refinement_response)} candidates from refinement, total candidates now: {self.original_len}"
        )
        self.refinement_done = True

        return self.get_next_candidate()

    def is_done(self) -> bool:
        """Return ``True`` once both async sources are consumed and the queue is empty."""
        return self.line_profiler_done and self.refinement_done and self.candidate_queue.empty()
184+
107185class FunctionOptimizer :
108186 def __init__ (
109187 self ,
@@ -378,15 +456,13 @@ def determine_best_candidate(
378456 f"{ self .function_to_optimize .qualified_name } …"
379457 )
380458 console .rule ()
381- candidates = deque (candidates )
382- refinement_done = False
383- line_profiler_done = False
459+
384460 future_all_refinements : list [concurrent .futures .Future ] = []
385461 ast_code_to_id = {}
386462 valid_optimizations = []
387463 optimizations_post = {} # we need to overwrite some opt candidates' code strings as they are no longer evaluated, instead their shorter/longer versions might be evaluated
388- # Start a new thread for AI service request, start loop in main thread
389- # check if aiservice request is complete, when it is complete, append result to the candidates list
464+
465+ # Start a new thread for AI service request
390466 ai_service_client = self .aiservice_client if exp_type == "EXP0" else self .local_aiservice_client
391467 future_line_profile_results = self .executor .submit (
392468 ai_service_client .optimize_python_code_line_profiler ,
@@ -401,49 +477,23 @@ def determine_best_candidate(
401477 if self .experiment_id
402478 else None ,
403479 )
480+
481+ # Initialize candidate processor
482+ processor = CandidateProcessor (candidates , future_line_profile_results , future_all_refinements )
404483 candidate_index = 0
405- original_len = len (candidates )
406- # TODO : We need to rewrite this candidate loop as a class, the container which has candidates receives new candidates at unknown times due to the async nature of lp and refinement calls,
407- # TODO : in addition, the refinement calls depend on line profiler calls being complete so we need to check that reliably
408- # write a class with queue.Queue architecture
409- while True :
410- try :
411- if len (candidates ) > 0 :
412- candidate = candidates .popleft ()
413- else :
414- if not line_profiler_done :
415- logger .debug ("all candidates processed, await candidates from line profiler" )
416- concurrent .futures .wait ([future_line_profile_results ])
417- line_profile_results = future_line_profile_results .result ()
418- candidates .extend (line_profile_results )
419- original_len += len (line_profile_results )
420- logger .info (
421- f"Added results from line profiler to candidates, total candidates now: { original_len } "
422- )
423- line_profiler_done = True
424- continue
425- if line_profiler_done and not refinement_done :
426- concurrent .futures .wait (future_all_refinements )
427- refinement_response = []
428- for future_refinement in future_all_refinements :
429- possible_refinement = future_refinement .result ()
430- if len (possible_refinement ) > 0 : # if the api returns a valid response
431- refinement_response .append (possible_refinement [0 ])
432- candidates .extend (refinement_response )
433- original_len += len (refinement_response )
434- logger .info (
435- f"Added { len (refinement_response )} candidates from refinement, total candidates now: { original_len } "
436- )
437- refinement_done = True
438- continue
439- if line_profiler_done and refinement_done :
440- logger .debug ("everything done, exiting" )
441- break
442484
485+ # Process candidates using queue-based approach
486+ while not processor .is_done ():
487+ candidate = processor .get_next_candidate ()
488+ if candidate is None :
489+ logger .debug ("everything done, exiting" )
490+ break
491+
492+ try :
443493 candidate_index += 1
444494 get_run_tmp_file (Path (f"test_return_values_{ candidate_index } .bin" )).unlink (missing_ok = True )
445495 get_run_tmp_file (Path (f"test_return_values_{ candidate_index } .sqlite" )).unlink (missing_ok = True )
446- logger .info (f"Optimization candidate { candidate_index } /{ original_len } :" )
496+ logger .info (f"Optimization candidate { candidate_index } /{ processor . original_len } :" )
447497 code_print (candidate .source_code .flat )
448498 # map ast normalized code to diff len, unnormalized code
449499 # map opt id to the shortest unnormalized code
@@ -468,7 +518,7 @@ def determine_best_candidate(
468518 # check if this code has been evaluated before by checking the ast normalized code string
469519 normalized_code = ast .unparse (ast .parse (candidate .source_code .flat .strip ()))
470520 if normalized_code in ast_code_to_id :
471- logger .warning (
521+ logger .info (
472522 "Current candidate has been encountered before in testing, Skipping optimization candidate."
473523 )
474524 past_opt_id = ast_code_to_id [normalized_code ]["optimization_id" ]
@@ -746,7 +796,9 @@ def reformat_code_and_helpers(
746796 file_to_code_context = optimized_context .file_to_path ()
747797 optimized_code = file_to_code_context .get (str (path .relative_to (self .project_root )), "" )
748798
749- new_code = format_code (self .args .formatter_cmds , path , optimized_code = optimized_code , check_diff = True )
799+ new_code = format_code (
800+ self .args .formatter_cmds , path , optimized_code = optimized_code , check_diff = True , exit_on_failure = False
801+ )
750802 if should_sort_imports :
751803 new_code = sort_imports (new_code )
752804
@@ -755,7 +807,11 @@ def reformat_code_and_helpers(
755807 module_abspath = hp .file_path
756808 hp_source_code = hp .source_code
757809 formatted_helper_code = format_code (
758- self .args .formatter_cmds , module_abspath , optimized_code = hp_source_code , check_diff = True
810+ self .args .formatter_cmds ,
811+ module_abspath ,
812+ optimized_code = hp_source_code ,
813+ check_diff = True ,
814+ exit_on_failure = False ,
759815 )
760816 if should_sort_imports :
761817 formatted_helper_code = sort_imports (formatted_helper_code )
@@ -1153,7 +1209,6 @@ def find_and_process_best_optimization(
11531209 original_helper_code ,
11541210 code_context ,
11551211 )
1156- self .log_successful_optimization (explanation , generated_tests , exp_type )
11571212 return best_optimization
11581213
11591214 def process_review (
@@ -1227,7 +1282,10 @@ def process_review(
12271282 file_path = explanation .file_path ,
12281283 benchmark_details = explanation .benchmark_details ,
12291284 )
1230- console .print (Panel (new_explanation_raw_str , title = "Best Candidate Explanation" , border_style = "blue" ))
1285+ self .log_successful_optimization (new_explanation , generated_tests , exp_type )
1286+
1287+ best_optimization .explanation_v2 = new_explanation .explanation_message ()
1288+
12311289 data = {
12321290 "original_code" : original_code_combined ,
12331291 "new_code" : new_code_combined ,
@@ -1240,6 +1298,7 @@ def process_review(
12401298 "coverage_message" : coverage_message ,
12411299 "replay_tests" : replay_tests ,
12421300 "concolic_tests" : concolic_tests ,
1301+ "root_dir" : self .project_root ,
12431302 }
12441303
12451304 raise_pr = not self .args .no_pr
@@ -1248,13 +1307,35 @@ def process_review(
12481307 data ["git_remote" ] = self .args .git_remote
12491308 check_create_pr (** data )
12501309 elif self .args .staging_review :
1251- create_staging (** data )
1310+ response = create_staging (** data )
1311+ if response .status_code == 200 :
1312+ staging_url = f"https://app.codeflash.ai/review-optimizations/{ self .function_trace_id [:- 4 ] + exp_type if self .experiment_id else self .function_trace_id } "
1313+ console .print (
1314+ Panel (
1315+ f"[bold green]✅ Staging created:[/bold green]\n [link={ staging_url } ]{ staging_url } [/link]" ,
1316+ title = "Staging Link" ,
1317+ border_style = "green" ,
1318+ )
1319+ )
1320+ else :
1321+ console .print (
1322+ Panel (
1323+ f"[bold red]❌ Failed to create staging[/bold red]\n Status: { response .status_code } " ,
1324+ title = "Staging Error" ,
1325+ border_style = "red" ,
1326+ )
1327+ )
1328+
12521329 else :
12531330 # Mark optimization success since no PR will be created
12541331 mark_optimization_success (
12551332 trace_id = self .function_trace_id , is_optimization_found = best_optimization is not None
12561333 )
12571334
1335+ # In worktree mode, do not revert code and helpers; otherwise we would have an empty diff when writing the patch in the LSP
1336+ if self .args .worktree :
1337+ return
1338+
12581339 if raise_pr and (
12591340 self .args .all
12601341 or env_utils .get_pr_number ()
0 commit comments