33import ast
44import concurrent .futures
55import os
6+ import queue
67import random
78import subprocess
89import time
910import uuid
10- from collections import defaultdict , deque
11+ from collections import defaultdict
1112from pathlib import Path
1213from typing import TYPE_CHECKING
1314
104105 from codeflash .verification .verification_utils import TestConfig
105106
106107
108+ class CandidateProcessor :
109+ """Handles candidate processing using a queue-based approach."""
110+
111+ def __init__ (
112+ self ,
113+ initial_candidates : list ,
114+ future_line_profile_results : concurrent .futures .Future ,
115+ future_all_refinements : list ,
116+ ) -> None :
117+ self .candidate_queue = queue .Queue ()
118+ self .line_profiler_done = False
119+ self .refinement_done = False
120+ self .candidate_len = len (initial_candidates )
121+
122+ # Initialize queue with initial candidates
123+ for candidate in initial_candidates :
124+ self .candidate_queue .put (candidate )
125+
126+ self .future_line_profile_results = future_line_profile_results
127+ self .future_all_refinements = future_all_refinements
128+
129+ def get_next_candidate (self ) -> OptimizedCandidate | None :
130+ """Get the next candidate from the queue, handling async results as needed."""
131+ try :
132+ return self .candidate_queue .get_nowait ()
133+ except queue .Empty :
134+ return self ._handle_empty_queue ()
135+
136+ def _handle_empty_queue (self ) -> OptimizedCandidate | None :
137+ """Handle empty queue by checking for pending async results."""
138+ if not self .line_profiler_done :
139+ return self ._process_line_profiler_results ()
140+ if self .line_profiler_done and not self .refinement_done :
141+ return self ._process_refinement_results ()
142+ return None # All done
143+
144+ def _process_line_profiler_results (self ) -> OptimizedCandidate | None :
145+ """Process line profiler results and add to queue."""
146+ logger .debug ("all candidates processed, await candidates from line profiler" )
147+ concurrent .futures .wait ([self .future_line_profile_results ])
148+ line_profile_results = self .future_line_profile_results .result ()
149+
150+ for candidate in line_profile_results :
151+ self .candidate_queue .put (candidate )
152+
153+ self .candidate_len += len (line_profile_results )
154+ logger .info (f"Added results from line profiler to candidates, total candidates now: { self .candidate_len } " )
155+ self .line_profiler_done = True
156+
157+ return self .get_next_candidate ()
158+
159+ def _process_refinement_results (self ) -> OptimizedCandidate | None :
160+ """Process refinement results and add to queue."""
161+ concurrent .futures .wait (self .future_all_refinements )
162+ refinement_response = []
163+
164+ for future_refinement in self .future_all_refinements :
165+ possible_refinement = future_refinement .result ()
166+ if len (possible_refinement ) > 0 :
167+ refinement_response .append (possible_refinement [0 ])
168+
169+ for candidate in refinement_response :
170+ self .candidate_queue .put (candidate )
171+
172+ self .candidate_len += len (refinement_response )
173+ logger .info (
174+ f"Added { len (refinement_response )} candidates from refinement, total candidates now: { self .candidate_len } "
175+ )
176+ self .refinement_done = True
177+
178+ return self .get_next_candidate ()
179+
180+ def is_done (self ) -> bool :
181+ """Check if processing is complete."""
182+ return self .line_profiler_done and self .refinement_done and self .candidate_queue .empty ()
183+
184+
107185class FunctionOptimizer :
108186 def __init__ (
109187 self ,
@@ -378,15 +456,13 @@ def determine_best_candidate(
378456 f"{ self .function_to_optimize .qualified_name } …"
379457 )
380458 console .rule ()
381- candidates = deque (candidates )
382- refinement_done = False
383- line_profiler_done = False
459+
384460 future_all_refinements : list [concurrent .futures .Future ] = []
385461 ast_code_to_id = {}
386462 valid_optimizations = []
387463 optimizations_post = {} # we need to overwrite some opt candidates' code strings as they are no longer evaluated, instead their shorter/longer versions might be evaluated
388- # Start a new thread for AI service request, start loop in main thread
389- # check if aiservice request is complete, when it is complete, append result to the candidates list
464+
465+ # Start a new thread for AI service request
390466 ai_service_client = self .aiservice_client if exp_type == "EXP0" else self .local_aiservice_client
391467 future_line_profile_results = self .executor .submit (
392468 ai_service_client .optimize_python_code_line_profiler ,
@@ -401,48 +477,23 @@ def determine_best_candidate(
401477 if self .experiment_id
402478 else None ,
403479 )
480+
481+ # Initialize candidate processor
482+ processor = CandidateProcessor (candidates , future_line_profile_results , future_all_refinements )
404483 candidate_index = 0
405- original_len = len (candidates )
406- # TODO : We need to rewrite this candidate loop as a class, the container which has candidates receives new candidates at unknown times due to the async nature of lp and refinement calls,
407- # TODO : in addition, the refinement calls depend on line profiler calls being complete so we need to check that reliably
408- while True :
409- try :
410- if len (candidates ) > 0 :
411- candidate = candidates .popleft ()
412- else :
413- if not line_profiler_done :
414- logger .debug ("all candidates processed, await candidates from line profiler" )
415- concurrent .futures .wait ([future_line_profile_results ])
416- line_profile_results = future_line_profile_results .result ()
417- candidates .extend (line_profile_results )
418- original_len += len (line_profile_results )
419- logger .info (
420- f"Added results from line profiler to candidates, total candidates now: { original_len } "
421- )
422- line_profiler_done = True
423- continue
424- if line_profiler_done and not refinement_done :
425- concurrent .futures .wait (future_all_refinements )
426- refinement_response = []
427- for future_refinement in future_all_refinements :
428- possible_refinement = future_refinement .result ()
429- if len (possible_refinement ) > 0 : # if the api returns a valid response
430- refinement_response .append (possible_refinement [0 ])
431- candidates .extend (refinement_response )
432- original_len += len (refinement_response )
433- logger .info (
434- f"Added { len (refinement_response )} candidates from refinement, total candidates now: { original_len } "
435- )
436- refinement_done = True
437- continue
438- if line_profiler_done and refinement_done :
439- logger .debug ("everything done, exiting" )
440- break
441484
485+ # Process candidates using queue-based approach
486+ while not processor .is_done ():
487+ candidate = processor .get_next_candidate ()
488+ if candidate is None :
489+ logger .debug ("everything done, exiting" )
490+ break
491+
492+ try :
442493 candidate_index += 1
443494 get_run_tmp_file (Path (f"test_return_values_{ candidate_index } .bin" )).unlink (missing_ok = True )
444495 get_run_tmp_file (Path (f"test_return_values_{ candidate_index } .sqlite" )).unlink (missing_ok = True )
445- logger .info (f"Optimization candidate { candidate_index } /{ original_len } :" )
496+ logger .info (f"Optimization candidate { candidate_index } /{ processor . candidate_len } :" )
446497 code_print (candidate .source_code .flat )
447498 # map ast normalized code to diff len, unnormalized code
448499 # map opt id to the shortest unnormalized code
@@ -467,7 +518,7 @@ def determine_best_candidate(
467518 # check if this code has been evaluated before by checking the ast normalized code string
468519 normalized_code = ast .unparse (ast .parse (candidate .source_code .flat .strip ()))
469520 if normalized_code in ast_code_to_id :
470- logger .warning (
521+ logger .info (
471522 "Current candidate has been encountered before in testing, Skipping optimization candidate."
472523 )
473524 past_opt_id = ast_code_to_id [normalized_code ]["optimization_id" ]
0 commit comments