Skip to content

Commit 2802ae6

Browse files
Merge pull request #717 from codeflash-ai/ranker
Ranking different Optimization candidates based on their speedup and diff
2 parents 4bcd81c + bdd1461 commit 2802ae6

File tree

5 files changed

+136
-43
lines changed

5 files changed

+136
-43
lines changed

codeflash/api/aiservice.py

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -353,6 +353,51 @@ def get_new_explanation( # noqa: D417
353353
console.rule()
354354
return ""
355355

356+
def generate_ranking( # noqa: D417
357+
self, trace_id: str, diffs: list[str], optimization_ids: list[str], speedups: list[float]
358+
) -> list[int] | None:
359+
"""Optimize the given python code for performance by making a request to the Django endpoint.
360+
361+
Parameters
362+
----------
363+
- trace_id : unique uuid of function
364+
- diffs : list of unified diff strings of opt candidates
365+
- speedups : list of speedups of opt candidates
366+
367+
Returns
368+
-------
369+
- List[int]: Ranking of opt candidates in decreasing order
370+
371+
"""
372+
payload = {
373+
"trace_id": trace_id,
374+
"diffs": diffs,
375+
"speedups": speedups,
376+
"optimization_ids": optimization_ids,
377+
"python_version": platform.python_version(),
378+
}
379+
logger.info("Generating ranking")
380+
console.rule()
381+
try:
382+
response = self.make_ai_service_request("/rank", payload=payload, timeout=60)
383+
except requests.exceptions.RequestException as e:
384+
logger.exception(f"Error generating ranking: {e}")
385+
ph("cli-optimize-error-caught", {"error": str(e)})
386+
return None
387+
388+
if response.status_code == 200:
389+
ranking: list[int] = response.json()["ranking"]
390+
console.rule()
391+
return ranking
392+
try:
393+
error = response.json()["error"]
394+
except Exception:
395+
error = response.text
396+
logger.error(f"Error generating ranking: {response.status_code} - {error}")
397+
ph("cli-optimize-error-response", {"response_status_code": response.status_code, "error": error})
398+
console.rule()
399+
return None
400+
356401
def log_results( # noqa: D417
357402
self,
358403
function_trace_id: str,

codeflash/code_utils/code_utils.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,23 @@
2020
ImportErrorPattern = re.compile(r"ModuleNotFoundError.*$", re.MULTILINE)
2121

2222

23+
def unified_diff_strings(code1: str, code2: str, fromfile: str = "original", tofile: str = "modified") -> str:
24+
"""Return the unified diff between two code strings as a single string.
25+
26+
:param code1: First code string (original).
27+
:param code2: Second code string (modified).
28+
:param fromfile: Label for the first code string.
29+
:param tofile: Label for the second code string.
30+
:return: Unified diff as a string.
31+
"""
32+
code1_lines = code1.splitlines(keepends=True)
33+
code2_lines = code2.splitlines(keepends=True)
34+
35+
diff = difflib.unified_diff(code1_lines, code2_lines, fromfile=fromfile, tofile=tofile, lineterm="")
36+
37+
return "".join(diff)
38+
39+
2340
def diff_length(a: str, b: str) -> int:
2441
"""Compute the length (in characters) of the unified diff between two strings.
2542

codeflash/code_utils/version_check.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -7,11 +7,11 @@
77
import requests
88
from packaging import version
99

10-
from codeflash.cli_cmds.console import console, logger
10+
from codeflash.cli_cmds.console import logger
1111
from codeflash.version import __version__
1212

1313
# Simple cache to avoid checking too frequently
14-
_version_cache = {"version": '0.0.0', "timestamp": float(0)}
14+
_version_cache = {"version": "0.0.0", "timestamp": float(0)}
1515
_cache_duration = 3600 # 1 hour cache
1616

1717

@@ -69,10 +69,8 @@ def check_for_newer_minor_version() -> None:
6969

7070
# Check if there's a newer minor version available
7171
# We only notify for minor version updates, not patch updates
72-
if latest_parsed > current_parsed: # < > == operators can be directly applied on version objects
73-
logger.warning(
74-
f"A newer version({latest_version}) of Codeflash is available, please update soon!"
75-
)
72+
if latest_parsed > current_parsed: # < > == operators can be directly applied on version objects
73+
logger.warning(f"A newer version({latest_version}) of Codeflash is available, please update soon!")
7674

7775
except version.InvalidVersion as e:
7876
logger.debug(f"Invalid version format: {e}")

codeflash/optimization/function_optimizer.py

Lines changed: 43 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@
3939
has_any_async_functions,
4040
module_name_from_file_path,
4141
restore_conftest,
42+
unified_diff_strings,
4243
)
4344
from codeflash.code_utils.config_consts import (
4445
INDIVIDUAL_TESTCASE_TIMEOUT,
@@ -171,9 +172,10 @@ def _process_refinement_results(self) -> OptimizedCandidate | None:
171172
self.candidate_queue.put(candidate)
172173

173174
self.candidate_len += len(refinement_response)
174-
logger.info(
175-
f"Added {len(refinement_response)} candidates from refinement, total candidates now: {self.candidate_len}"
176-
)
175+
if len(refinement_response) > 0:
176+
logger.info(
177+
f"Added {len(refinement_response)} candidates from refinement, total candidates now: {self.candidate_len}"
178+
)
177179
self.refinement_done = True
178180

179181
return self.get_next_candidate()
@@ -537,7 +539,9 @@ def determine_best_candidate(
537539
].markdown
538540
optimizations_post[past_opt_id] = ast_code_to_id[normalized_code]["shorter_source_code"].markdown
539541
new_diff_len = diff_length(candidate.source_code.flat, code_context.read_writable_code.flat)
540-
if new_diff_len < ast_code_to_id[normalized_code]["diff_len"]:
542+
if (
543+
new_diff_len < ast_code_to_id[normalized_code]["diff_len"]
544+
): # new candidate has a shorter diff than the previously encountered one
541545
ast_code_to_id[normalized_code]["shorter_source_code"] = candidate.source_code
542546
ast_code_to_id[normalized_code]["diff_len"] = new_diff_len
543547
continue
@@ -660,6 +664,9 @@ def determine_best_candidate(
660664
# reassign the shorter code here
661665
valid_candidates_with_shorter_code = []
662666
diff_lens_list = [] # character level diff
667+
speedups_list = []
668+
optimization_ids = []
669+
diff_strs = []
663670
runtimes_list = []
664671
for valid_opt in valid_optimizations:
665672
valid_opt_normalized_code = ast.unparse(ast.parse(valid_opt.candidate.source_code.flat.strip()))
@@ -683,12 +690,39 @@ def determine_best_candidate(
683690
diff_lens_list.append(
684691
diff_length(new_best_opt.candidate.source_code.flat, code_context.read_writable_code.flat)
685692
) # char level diff
693+
diff_strs.append(
694+
unified_diff_strings(code_context.read_writable_code.flat, new_best_opt.candidate.source_code.flat)
695+
)
696+
speedups_list.append(
697+
1
698+
+ performance_gain(
699+
original_runtime_ns=original_code_baseline.runtime, optimized_runtime_ns=new_best_opt.runtime
700+
)
701+
)
702+
optimization_ids.append(new_best_opt.candidate.optimization_id)
686703
runtimes_list.append(new_best_opt.runtime)
687-
diff_lens_ranking = create_rank_dictionary_compact(diff_lens_list)
688-
runtimes_ranking = create_rank_dictionary_compact(runtimes_list)
689-
# TODO: better way to resolve conflicts with same min ranking
690-
overall_ranking = {key: diff_lens_ranking[key] + runtimes_ranking[key] for key in diff_lens_ranking.keys()} # noqa: SIM118
691-
min_key = min(overall_ranking, key=overall_ranking.get)
704+
if len(optimization_ids) > 1:
705+
future_ranking = self.executor.submit(
706+
ai_service_client.generate_ranking,
707+
diffs=diff_strs,
708+
optimization_ids=optimization_ids,
709+
speedups=speedups_list,
710+
trace_id=self.function_trace_id[:-4] + exp_type if self.experiment_id else self.function_trace_id,
711+
)
712+
concurrent.futures.wait([future_ranking])
713+
ranking = future_ranking.result()
714+
if ranking:
715+
min_key = ranking[0]
716+
else:
717+
diff_lens_ranking = create_rank_dictionary_compact(diff_lens_list)
718+
runtimes_ranking = create_rank_dictionary_compact(runtimes_list)
719+
# TODO: better way to resolve conflicts with same min ranking
720+
overall_ranking = {key: diff_lens_ranking[key] + runtimes_ranking[key] for key in diff_lens_ranking}
721+
min_key = min(overall_ranking, key=overall_ranking.get)
722+
elif len(optimization_ids) == 1:
723+
min_key = 0 # only one candidate in valid _opts, already returns if there are no valid candidates
724+
else: # 0? shouldn't happen but it's there to escape potential bugs
725+
return None
692726
best_optimization = valid_candidates_with_shorter_code[min_key]
693727
# reassign code string which is the shortest
694728
ai_service_client.log_results(

tests/test_version_check.py

Lines changed: 27 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -120,87 +120,86 @@ def test_get_latest_version_from_pypi_cache_expiry(self, mock_get):
120120
self.assertEqual(mock_get.call_count, 2)
121121

122122
@patch('codeflash.code_utils.version_check.get_latest_version_from_pypi')
123-
@patch('codeflash.code_utils.version_check.console')
123+
@patch('codeflash.code_utils.version_check.logger')
124124
@patch('codeflash.code_utils.version_check.__version__', '1.0.0')
125-
def test_check_for_newer_minor_version_newer_available(self, mock_console, mock_get_version):
125+
def test_check_for_newer_minor_version_newer_available(self, mock_logger,mock_get_version):
126126
"""Test warning message when newer minor version is available."""
127127
mock_get_version.return_value = "1.1.0"
128128

129129
check_for_newer_minor_version()
130130

131-
mock_console.print.assert_called_once()
132-
call_args = mock_console.print.call_args[0][0]
133-
self.assertIn("ℹ️ A newer version of Codeflash is available!", call_args)
134-
self.assertIn("Current version: 1.0.0", call_args)
135-
self.assertIn("Latest version: 1.1.0", call_args)
131+
mock_logger.warning.assert_called_once()
132+
call_args = mock_logger.warning.call_args[0][0]
133+
self.assertIn("of Codeflash is available, please update soon!", call_args)
134+
self.assertIn("1.1.0", call_args)
136135

137136
@patch('codeflash.code_utils.version_check.get_latest_version_from_pypi')
138-
@patch('codeflash.code_utils.version_check.console')
137+
@patch('codeflash.code_utils.version_check.logger')
139138
@patch('codeflash.code_utils.version_check.__version__', '1.0.0')
140-
def test_check_for_newer_minor_version_newer_major_available(self, mock_console, mock_get_version):
139+
def test_check_for_newer_minor_version_newer_major_available(self, mock_logger,mock_get_version):
141140
"""Test warning message when newer major version is available."""
142141
mock_get_version.return_value = "2.0.0"
143142

144143
check_for_newer_minor_version()
145144

146-
mock_console.print.assert_called_once()
147-
call_args = mock_console.print.call_args[0][0]
148-
self.assertIn("ℹ️ A newer version of Codeflash is available!", call_args)
145+
mock_logger.warning.assert_called_once()
146+
call_args = mock_logger.warning.call_args[0][0]
147+
self.assertIn("of Codeflash is available, please update soon!", call_args)
149148

150149
@patch('codeflash.code_utils.version_check.get_latest_version_from_pypi')
151-
@patch('codeflash.code_utils.version_check.console')
150+
@patch('codeflash.code_utils.version_check.logger')
152151
@patch('codeflash.code_utils.version_check.__version__', '1.1.0')
153-
def test_check_for_newer_minor_version_no_newer_available(self, mock_console, mock_get_version):
152+
def test_check_for_newer_minor_version_no_newer_available(self, mock_logger,mock_get_version):
154153
"""Test no warning when no newer version is available."""
155154
mock_get_version.return_value = "1.0.0"
156155

157156
check_for_newer_minor_version()
158157

159-
mock_console.print.assert_not_called()
158+
mock_logger.warning.assert_not_called()
160159

161160
@patch('codeflash.code_utils.version_check.get_latest_version_from_pypi')
162-
@patch('codeflash.code_utils.version_check.console')
163-
@patch('codeflash.code_utils.version_check.__version__', '1.0.0')
164-
def test_check_for_newer_minor_version_patch_update_ignored(self, mock_console, mock_get_version):
161+
@patch('codeflash.code_utils.version_check.logger')
162+
@patch('codeflash.code_utils.version_check.__version__', '1.0.1')
163+
def test_check_for_newer_minor_version_patch_update_ignored(self, mock_logger,mock_get_version):
165164
"""Test that patch updates don't trigger warnings."""
166165
mock_get_version.return_value = "1.0.1"
167166

168167
check_for_newer_minor_version()
169168

170-
mock_console.print.assert_not_called()
169+
mock_logger.warning.assert_not_called()
171170

172171
@patch('codeflash.code_utils.version_check.get_latest_version_from_pypi')
173-
@patch('codeflash.code_utils.version_check.console')
172+
@patch('codeflash.code_utils.version_check.logger')
174173
@patch('codeflash.code_utils.version_check.__version__', '1.0.0')
175-
def test_check_for_newer_minor_version_same_version(self, mock_console, mock_get_version):
174+
def test_check_for_newer_minor_version_same_version(self, mock_logger,mock_get_version):
176175
"""Test no warning when versions are the same."""
177176
mock_get_version.return_value = "1.0.0"
178177

179178
check_for_newer_minor_version()
180179

181-
mock_console.print.assert_not_called()
180+
mock_logger.warning.assert_not_called()
182181

183182
@patch('codeflash.code_utils.version_check.get_latest_version_from_pypi')
184-
@patch('codeflash.code_utils.version_check.console')
183+
@patch('codeflash.code_utils.version_check.logger')
185184
@patch('codeflash.code_utils.version_check.__version__', '1.0.0')
186-
def test_check_for_newer_minor_version_no_latest_version(self, mock_console, mock_get_version):
185+
def test_check_for_newer_minor_version_no_latest_version(self, mock_logger,mock_get_version):
187186
"""Test no warning when latest version cannot be fetched."""
188187
mock_get_version.return_value = None
189188

190189
check_for_newer_minor_version()
191190

192-
mock_console.print.assert_not_called()
191+
mock_logger.warning.assert_not_called()
193192

194193
@patch('codeflash.code_utils.version_check.get_latest_version_from_pypi')
195-
@patch('codeflash.code_utils.version_check.console')
194+
@patch('codeflash.code_utils.version_check.logger')
196195
@patch('codeflash.code_utils.version_check.__version__', '1.0.0')
197-
def test_check_for_newer_minor_version_invalid_version_format(self, mock_console, mock_get_version):
196+
def test_check_for_newer_minor_version_invalid_version_format(self, mock_logger,mock_get_version):
198197
"""Test handling of invalid version format."""
199198
mock_get_version.return_value = "invalid-version"
200199

201200
check_for_newer_minor_version()
202201

203-
mock_console.print.assert_not_called()
202+
mock_logger.warning.assert_not_called()
204203

205204

206205

0 commit comments

Comments
 (0)