Skip to content

Commit c8ae694

Browse files
authored
Merge pull request #758 from codeflash-ai/chore/asyncio-optimization
[Chore] Async implementation for perform_function_optimization
2 parents c09a334 + 33f15e6 commit c8ae694

File tree

3 files changed

+161
-111
lines changed

3 files changed

+161
-111
lines changed

codeflash/lsp/beta.py

Lines changed: 9 additions & 111 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from __future__ import annotations
22

3-
import contextlib
3+
import asyncio
44
import os
55
from dataclasses import dataclass
66
from pathlib import Path
@@ -11,16 +11,15 @@
1111

1212
from codeflash.api.cfapi import get_codeflash_api_key, get_user_id
1313
from codeflash.cli_cmds.cli import process_pyproject_config
14-
from codeflash.cli_cmds.console import code_print
1514
from codeflash.code_utils.git_utils import git_root_dir
16-
from codeflash.code_utils.git_worktree_utils import create_diff_patch_from_worktree
1715
from codeflash.code_utils.shell_utils import save_api_key_to_rc
1816
from codeflash.discovery.functions_to_optimize import (
1917
filter_functions,
2018
get_functions_inside_a_commit,
2119
get_functions_within_git_diff,
2220
)
2321
from codeflash.either import is_successful
22+
from codeflash.lsp.features.perform_optimization import sync_perform_optimization
2423
from codeflash.lsp.server import CodeflashLanguageServer
2524

2625
if TYPE_CHECKING:
@@ -71,7 +70,6 @@ class OptimizableFunctionsInCommitParams:
7170
commit_hash: str
7271

7372

74-
# server = CodeflashLanguageServer("codeflash-language-server", "v1.0", protocol_cls=CodeflashLanguageServerProtocol)
7573
server = CodeflashLanguageServer("codeflash-language-server", "v1.0")
7674

7775

@@ -339,115 +337,15 @@ def initialize_function_optimization(
339337

340338

341339
@server.feature("performFunctionOptimization")
342-
@server.thread()
343-
def perform_function_optimization(
340+
async def perform_function_optimization(
344341
server: CodeflashLanguageServer, params: FunctionOptimizationParams
345342
) -> dict[str, str]:
343+
loop = asyncio.get_running_loop()
346344
try:
347-
server.show_message_log(f"Starting optimization for function: {params.functionName}", "Info")
348-
should_run_experiment, code_context, original_helper_code = server.current_optimization_init_result
349-
function_optimizer = server.optimizer.current_function_optimizer
350-
current_function = function_optimizer.function_to_optimize
351-
352-
code_print(
353-
code_context.read_writable_code.flat,
354-
file_name=current_function.file_path,
355-
function_name=current_function.function_name,
356-
)
357-
358-
optimizable_funcs = {current_function.file_path: [current_function]}
359-
360-
devnull_writer = open(os.devnull, "w") # noqa
361-
with contextlib.redirect_stdout(devnull_writer):
362-
function_to_tests, _num_discovered_tests = server.optimizer.discover_tests(optimizable_funcs)
363-
function_optimizer.function_to_tests = function_to_tests
364-
365-
test_setup_result = function_optimizer.generate_and_instrument_tests(
366-
code_context, should_run_experiment=should_run_experiment
367-
)
368-
if not is_successful(test_setup_result):
369-
return {"functionName": params.functionName, "status": "error", "message": test_setup_result.failure()}
370-
(
371-
generated_tests,
372-
function_to_concolic_tests,
373-
concolic_test_str,
374-
optimizations_set,
375-
generated_test_paths,
376-
generated_perf_test_paths,
377-
instrumented_unittests_created_for_function,
378-
original_conftest_content,
379-
) = test_setup_result.unwrap()
380-
381-
baseline_setup_result = function_optimizer.setup_and_establish_baseline(
382-
code_context=code_context,
383-
original_helper_code=original_helper_code,
384-
function_to_concolic_tests=function_to_concolic_tests,
385-
generated_test_paths=generated_test_paths,
386-
generated_perf_test_paths=generated_perf_test_paths,
387-
instrumented_unittests_created_for_function=instrumented_unittests_created_for_function,
388-
original_conftest_content=original_conftest_content,
389-
)
390-
391-
if not is_successful(baseline_setup_result):
392-
return {"functionName": params.functionName, "status": "error", "message": baseline_setup_result.failure()}
393-
394-
(
395-
function_to_optimize_qualified_name,
396-
function_to_all_tests,
397-
original_code_baseline,
398-
test_functions_to_remove,
399-
file_path_to_helper_classes,
400-
) = baseline_setup_result.unwrap()
401-
402-
best_optimization = function_optimizer.find_and_process_best_optimization(
403-
optimizations_set=optimizations_set,
404-
code_context=code_context,
405-
original_code_baseline=original_code_baseline,
406-
original_helper_code=original_helper_code,
407-
file_path_to_helper_classes=file_path_to_helper_classes,
408-
function_to_optimize_qualified_name=function_to_optimize_qualified_name,
409-
function_to_all_tests=function_to_all_tests,
410-
generated_tests=generated_tests,
411-
test_functions_to_remove=test_functions_to_remove,
412-
concolic_test_str=concolic_test_str,
413-
)
414-
415-
if not best_optimization:
416-
server.show_message_log(
417-
f"No best optimizations found for function {function_to_optimize_qualified_name}", "Warning"
418-
)
419-
return {
420-
"functionName": params.functionName,
421-
"status": "error",
422-
"message": f"No best optimizations found for function {function_to_optimize_qualified_name}",
423-
}
424-
425-
# generate a patch for the optimization
426-
relative_file_paths = [code_string.file_path for code_string in code_context.read_writable_code.code_strings]
427-
428-
speedup = original_code_baseline.runtime / best_optimization.runtime
429-
430-
patch_path = create_diff_patch_from_worktree(
431-
server.optimizer.current_worktree, relative_file_paths, function_to_optimize_qualified_name
432-
)
433-
434-
if not patch_path:
435-
return {
436-
"functionName": params.functionName,
437-
"status": "error",
438-
"message": "Failed to create a patch for optimization",
439-
}
440-
441-
server.show_message_log(f"Optimization completed for {params.functionName} with {speedup:.2f}x speedup", "Info")
442-
443-
return {
444-
"functionName": params.functionName,
445-
"status": "success",
446-
"message": "Optimization completed successfully",
447-
"extra": f"Speedup: {speedup:.2f}x faster",
448-
"patch_file": str(patch_path),
449-
"task_id": params.task_id,
450-
"explanation": best_optimization.explanation_v2,
451-
}
345+
result = await loop.run_in_executor(None, sync_perform_optimization, server, params)
346+
except asyncio.CancelledError:
347+
return {"status": "canceled", "message": "Task was canceled"}
348+
else:
349+
return result
452350
finally:
453351
server.cleanup_the_optimizer()

codeflash/lsp/features/__init__.py

Whitespace-only changes.
Lines changed: 152 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,152 @@
1+
import contextlib
2+
import os
3+
from pathlib import Path
4+
5+
from codeflash.cli_cmds.console import code_print
6+
from codeflash.code_utils.git_worktree_utils import create_diff_patch_from_worktree
7+
from codeflash.either import is_successful
8+
from codeflash.lsp.server import CodeflashLanguageServer
9+
10+
11+
# ruff: noqa: PLR0911, ANN001
12+
def sync_perform_optimization(server: CodeflashLanguageServer, params) -> dict[str, str]:
13+
server.show_message_log(f"Starting optimization for function: {params.functionName}", "Info")
14+
current_function = server.optimizer.current_function_being_optimized
15+
16+
if not current_function:
17+
server.show_message_log(f"No current function being optimized for {params.functionName}", "Error")
18+
return {
19+
"functionName": params.functionName,
20+
"status": "error",
21+
"message": "No function currently being optimized",
22+
}
23+
24+
module_prep_result = server.optimizer.prepare_module_for_optimization(current_function.file_path)
25+
if not module_prep_result:
26+
return {
27+
"functionName": params.functionName,
28+
"status": "error",
29+
"message": "Failed to prepare module for optimization",
30+
}
31+
32+
validated_original_code, original_module_ast = module_prep_result
33+
34+
function_optimizer = server.optimizer.create_function_optimizer(
35+
current_function,
36+
function_to_optimize_source_code=validated_original_code[current_function.file_path].source_code,
37+
original_module_ast=original_module_ast,
38+
original_module_path=current_function.file_path,
39+
function_to_tests={},
40+
)
41+
42+
server.optimizer.current_function_optimizer = function_optimizer
43+
if not function_optimizer:
44+
return {"functionName": params.functionName, "status": "error", "message": "No function optimizer found"}
45+
46+
initialization_result = function_optimizer.can_be_optimized()
47+
if not is_successful(initialization_result):
48+
return {"functionName": params.functionName, "status": "error", "message": initialization_result.failure()}
49+
50+
should_run_experiment, code_context, original_helper_code = initialization_result.unwrap()
51+
52+
code_print(
53+
code_context.read_writable_code.flat,
54+
file_name=current_function.file_path,
55+
function_name=current_function.function_name,
56+
)
57+
58+
optimizable_funcs = {current_function.file_path: [current_function]}
59+
60+
devnull_writer = open(os.devnull, "w") # noqa
61+
with contextlib.redirect_stdout(devnull_writer):
62+
function_to_tests, _num_discovered_tests = server.optimizer.discover_tests(optimizable_funcs)
63+
function_optimizer.function_to_tests = function_to_tests
64+
65+
test_setup_result = function_optimizer.generate_and_instrument_tests(
66+
code_context, should_run_experiment=should_run_experiment
67+
)
68+
if not is_successful(test_setup_result):
69+
return {"functionName": params.functionName, "status": "error", "message": test_setup_result.failure()}
70+
(
71+
generated_tests,
72+
function_to_concolic_tests,
73+
concolic_test_str,
74+
optimizations_set,
75+
generated_test_paths,
76+
generated_perf_test_paths,
77+
instrumented_unittests_created_for_function,
78+
original_conftest_content,
79+
) = test_setup_result.unwrap()
80+
81+
baseline_setup_result = function_optimizer.setup_and_establish_baseline(
82+
code_context=code_context,
83+
original_helper_code=original_helper_code,
84+
function_to_concolic_tests=function_to_concolic_tests,
85+
generated_test_paths=generated_test_paths,
86+
generated_perf_test_paths=generated_perf_test_paths,
87+
instrumented_unittests_created_for_function=instrumented_unittests_created_for_function,
88+
original_conftest_content=original_conftest_content,
89+
)
90+
91+
if not is_successful(baseline_setup_result):
92+
return {"functionName": params.functionName, "status": "error", "message": baseline_setup_result.failure()}
93+
94+
(
95+
function_to_optimize_qualified_name,
96+
function_to_all_tests,
97+
original_code_baseline,
98+
test_functions_to_remove,
99+
file_path_to_helper_classes,
100+
) = baseline_setup_result.unwrap()
101+
102+
best_optimization = function_optimizer.find_and_process_best_optimization(
103+
optimizations_set=optimizations_set,
104+
code_context=code_context,
105+
original_code_baseline=original_code_baseline,
106+
original_helper_code=original_helper_code,
107+
file_path_to_helper_classes=file_path_to_helper_classes,
108+
function_to_optimize_qualified_name=function_to_optimize_qualified_name,
109+
function_to_all_tests=function_to_all_tests,
110+
generated_tests=generated_tests,
111+
test_functions_to_remove=test_functions_to_remove,
112+
concolic_test_str=concolic_test_str,
113+
)
114+
115+
if not best_optimization:
116+
server.show_message_log(
117+
f"No best optimizations found for function {function_to_optimize_qualified_name}", "Warning"
118+
)
119+
return {
120+
"functionName": params.functionName,
121+
"status": "error",
122+
"message": f"No best optimizations found for function {function_to_optimize_qualified_name}",
123+
}
124+
# generate a patch for the optimization
125+
relative_file_paths = [code_string.file_path for code_string in code_context.read_writable_code.code_strings]
126+
speedup = original_code_baseline.runtime / best_optimization.runtime
127+
# get the original file path in the actual project (not in the worktree)
128+
original_args, _ = server.optimizer.original_args_and_test_cfg
129+
relative_file_path = current_function.file_path.relative_to(server.optimizer.current_worktree)
130+
original_file_path = Path(original_args.project_root / relative_file_path).resolve()
131+
132+
metadata = create_diff_patch_from_worktree(
133+
server.optimizer.current_worktree,
134+
relative_file_paths,
135+
metadata_input={
136+
"fto_name": function_to_optimize_qualified_name,
137+
"explanation": best_optimization.explanation_v2,
138+
"file_path": str(original_file_path),
139+
"speedup": speedup,
140+
},
141+
)
142+
143+
server.show_message_log(f"Optimization completed for {params.functionName} with {speedup:.2f}x speedup", "Info")
144+
return {
145+
"functionName": params.functionName,
146+
"status": "success",
147+
"message": "Optimization completed successfully",
148+
"extra": f"Speedup: {speedup:.2f}x faster",
149+
"patch_file": metadata["patch_path"],
150+
"patch_id": metadata["id"],
151+
"explanation": best_optimization.explanation_v2,
152+
}

0 commit comments

Comments
 (0)