Skip to content

Commit 8de932e

Browse files
committed
code review changes
1 parent 3ba531e commit 8de932e

File tree

2 files changed

+26
-30
lines changed

2 files changed

+26
-30
lines changed

codeflash/api/aiservice.py

Lines changed: 4 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -360,21 +360,13 @@ def generate_ranking( # noqa: D417
360360
361361
Parameters
362362
----------
363-
- source_code (str): The python code to optimize.
364-
- optimized_code (str): The python code generated by the AI service.
365-
- dependency_code (str): The dependency code used as read-only context for the optimization
366-
- original_line_profiler_results: str - line profiler results for the baseline code
367-
- optimized_line_profiler_results: str - line profiler results for the optimized code
368-
- original_code_runtime: str - runtime for the baseline code
369-
- optimized_code_runtime: str - runtime for the optimized code
370-
- speedup: str - speedup of the optimized code
371-
- annotated_tests: str - test functions annotated with runtime
372-
- optimization_id: str - unique id of opt candidate
373-
- original_explanation: str - original_explanation generated for the opt candidate
363+
- trace_id : unique uuid of function
364+
- diffs : list of unified diff strings of opt candidates
365+
- speedups : list of speedups of opt candidates
374366
375367
Returns
376368
-------
377-
- List[OptimizationCandidate]: A list of Optimization Candidates.
369+
- List[int]: Ranking of opt candidates in decreasing order
378370
379371
"""
380372
payload = {

codeflash/optimization/function_optimizer.py

Lines changed: 22 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -539,7 +539,9 @@ def determine_best_candidate(
539539
].markdown
540540
optimizations_post[past_opt_id] = ast_code_to_id[normalized_code]["shorter_source_code"].markdown
541541
new_diff_len = diff_length(candidate.source_code.flat, code_context.read_writable_code.flat)
542-
if new_diff_len < ast_code_to_id[normalized_code]["diff_len"]:
542+
if (
543+
new_diff_len < ast_code_to_id[normalized_code]["diff_len"]
544+
): # new candidate has a shorter diff than the previously encountered one
543545
ast_code_to_id[normalized_code]["shorter_source_code"] = candidate.source_code
544546
ast_code_to_id[normalized_code]["diff_len"] = new_diff_len
545547
continue
@@ -699,24 +701,26 @@ def determine_best_candidate(
699701
)
700702
optimization_ids.append(new_best_opt.candidate.optimization_id)
701703
runtimes_list.append(new_best_opt.runtime)
702-
future_ranking = self.executor.submit(
703-
ai_service_client.generate_ranking,
704-
diffs=diff_strs,
705-
optimization_ids=optimization_ids,
706-
speedups=speedups_list,
707-
trace_id=self.function_trace_id[:-4] + exp_type if self.experiment_id else self.function_trace_id,
708-
)
709-
concurrent.futures.wait([future_ranking])
710-
ranking = future_ranking.result()
711-
if ranking:
712-
ranking = [x - 1 for x in ranking]
713-
min_key = ranking[0]
704+
if len(optimization_ids) > 1:
705+
future_ranking = self.executor.submit(
706+
ai_service_client.generate_ranking,
707+
diffs=diff_strs,
708+
optimization_ids=optimization_ids,
709+
speedups=speedups_list,
710+
trace_id=self.function_trace_id[:-4] + exp_type if self.experiment_id else self.function_trace_id,
711+
)
712+
concurrent.futures.wait([future_ranking])
713+
ranking = future_ranking.result()
714+
if ranking:
715+
min_key = ranking[0]
716+
else:
717+
diff_lens_ranking = create_rank_dictionary_compact(diff_lens_list)
718+
runtimes_ranking = create_rank_dictionary_compact(runtimes_list)
719+
# TODO: better way to resolve conflicts with same min ranking
720+
overall_ranking = {key: diff_lens_ranking[key] + runtimes_ranking[key] for key in diff_lens_ranking}
721+
min_key = min(overall_ranking, key=overall_ranking.get)
714722
else:
715-
diff_lens_ranking = create_rank_dictionary_compact(diff_lens_list)
716-
runtimes_ranking = create_rank_dictionary_compact(runtimes_list)
717-
# TODO: better way to resolve conflicts with same min ranking
718-
overall_ranking = {key: diff_lens_ranking[key] + runtimes_ranking[key] for key in diff_lens_ranking.keys()} # noqa: SIM118
719-
min_key = min(overall_ranking, key=overall_ranking.get)
723+
min_key = 0 # only one candidate in valid _opts, already returns if there are no valid candidates
720724
best_optimization = valid_candidates_with_shorter_code[min_key]
721725
# reassign code string which is the shortest
722726
ai_service_client.log_results(

0 commit comments

Comments
 (0)