|
1 | 1 | from __future__ import annotations |
2 | 2 |
|
3 | | -import contextlib |
| 3 | +import asyncio |
4 | 4 | import os |
5 | 5 | from dataclasses import dataclass |
6 | 6 | from pathlib import Path |
7 | 7 | from typing import TYPE_CHECKING, Optional |
8 | 8 |
|
9 | | -import git |
10 | 9 | from pygls import uris |
11 | 10 |
|
12 | 11 | from codeflash.api.cfapi import get_codeflash_api_key, get_user_id |
13 | 12 | from codeflash.cli_cmds.cli import process_pyproject_config |
14 | | -from codeflash.cli_cmds.console import code_print |
15 | 13 | from codeflash.code_utils.git_utils import git_root_dir |
16 | | -from codeflash.code_utils.git_worktree_utils import create_diff_patch_from_worktree |
17 | 14 | from codeflash.code_utils.shell_utils import save_api_key_to_rc |
18 | 15 | from codeflash.discovery.functions_to_optimize import ( |
19 | 16 | filter_functions, |
20 | 17 | get_functions_inside_a_commit, |
21 | 18 | get_functions_within_git_diff, |
22 | 19 | ) |
23 | 20 | from codeflash.either import is_successful |
| 21 | +from codeflash.lsp.features.perform_optimization import sync_perform_optimization |
24 | 22 | from codeflash.lsp.server import CodeflashLanguageServer |
25 | 23 |
|
26 | 24 | if TYPE_CHECKING: |
@@ -71,7 +69,6 @@ class OptimizableFunctionsInCommitParams: |
71 | 69 | commit_hash: str |
72 | 70 |
|
73 | 71 |
|
74 | | -# server = CodeflashLanguageServer("codeflash-language-server", "v1.0", protocol_cls=CodeflashLanguageServerProtocol) |
75 | 72 | server = CodeflashLanguageServer("codeflash-language-server", "v1.0") |
76 | 73 |
|
77 | 74 |
|
@@ -157,11 +154,13 @@ def _find_pyproject_toml(workspace_path: str) -> tuple[Path | None, bool]: |
157 | 154 |
|
158 | 155 | # should be called the first thing to initialize and validate the project |
159 | 156 | @server.feature("initProject") |
160 | | -def init_project(server: CodeflashLanguageServer, params: ValidateProjectParams) -> dict[str, str]: # noqa: PLR0911 |
| 157 | +def init_project(server: CodeflashLanguageServer, params: ValidateProjectParams) -> dict[str, str]: |
161 | 158 | from codeflash.cli_cmds.cmd_init import is_valid_pyproject_toml |
162 | 159 |
|
163 | | - pyproject_toml_path: Path | None = getattr(params, "config_file", None) or getattr(server.args, "config_file", None) |
| 160 | + # Always process args in the init project, the extension can call |
| 161 | + server.args_processed_before = False |
164 | 162 |
|
| 163 | + pyproject_toml_path: Path | None = getattr(params, "config_file", None) or getattr(server.args, "config_file", None) |
165 | 164 | if pyproject_toml_path is not None: |
166 | 165 | # if there is a config file provided use it |
167 | 166 | server.prepare_optimizer_arguments(pyproject_toml_path) |
@@ -192,20 +191,12 @@ def init_project(server: CodeflashLanguageServer, params: ValidateProjectParams) |
192 | 191 | } |
193 | 192 |
|
194 | 193 | server.show_message_log("Validating project...", "Info") |
195 | | - config = is_valid_pyproject_toml(pyproject_toml_path) |
| 194 | + config, reason = is_valid_pyproject_toml(pyproject_toml_path) |
196 | 195 | if config is None: |
197 | 196 | server.show_message_log("pyproject.toml is not valid", "Error") |
198 | | - return {"status": "error", "message": "not valid", "pyprojectPath": pyproject_toml_path} |
| 197 | + return {"status": "error", "message": f"reason: {reason}", "pyprojectPath": pyproject_toml_path} |
199 | 198 |
|
200 | 199 | args = process_args(server) |
201 | | - repo = git.Repo(args.module_root, search_parent_directories=True) |
202 | | - if repo.bare: |
203 | | - return {"status": "error", "message": "Repository is in bare state"} |
204 | | - |
205 | | - try: |
206 | | - _ = repo.head.commit |
207 | | - except Exception: |
208 | | - return {"status": "error", "message": "Repository has no commits (unborn HEAD)"} |
209 | 200 |
|
210 | 201 | return {"status": "success", "moduleRoot": args.module_root, "pyprojectPath": pyproject_toml_path, "root": root} |
211 | 202 |
|
@@ -339,115 +330,15 @@ def initialize_function_optimization( |
339 | 330 |
|
340 | 331 |
|
341 | 332 | @server.feature("performFunctionOptimization") |
342 | | -@server.thread() |
343 | | -def perform_function_optimization( |
| 333 | +async def perform_function_optimization( |
344 | 334 | server: CodeflashLanguageServer, params: FunctionOptimizationParams |
345 | 335 | ) -> dict[str, str]: |
| 336 | + loop = asyncio.get_running_loop() |
346 | 337 | try: |
347 | | - server.show_message_log(f"Starting optimization for function: {params.functionName}", "Info") |
348 | | - should_run_experiment, code_context, original_helper_code = server.current_optimization_init_result |
349 | | - function_optimizer = server.optimizer.current_function_optimizer |
350 | | - current_function = function_optimizer.function_to_optimize |
351 | | - |
352 | | - code_print( |
353 | | - code_context.read_writable_code.flat, |
354 | | - file_name=current_function.file_path, |
355 | | - function_name=current_function.function_name, |
356 | | - ) |
357 | | - |
358 | | - optimizable_funcs = {current_function.file_path: [current_function]} |
359 | | - |
360 | | - devnull_writer = open(os.devnull, "w") # noqa |
361 | | - with contextlib.redirect_stdout(devnull_writer): |
362 | | - function_to_tests, _num_discovered_tests = server.optimizer.discover_tests(optimizable_funcs) |
363 | | - function_optimizer.function_to_tests = function_to_tests |
364 | | - |
365 | | - test_setup_result = function_optimizer.generate_and_instrument_tests( |
366 | | - code_context, should_run_experiment=should_run_experiment |
367 | | - ) |
368 | | - if not is_successful(test_setup_result): |
369 | | - return {"functionName": params.functionName, "status": "error", "message": test_setup_result.failure()} |
370 | | - ( |
371 | | - generated_tests, |
372 | | - function_to_concolic_tests, |
373 | | - concolic_test_str, |
374 | | - optimizations_set, |
375 | | - generated_test_paths, |
376 | | - generated_perf_test_paths, |
377 | | - instrumented_unittests_created_for_function, |
378 | | - original_conftest_content, |
379 | | - ) = test_setup_result.unwrap() |
380 | | - |
381 | | - baseline_setup_result = function_optimizer.setup_and_establish_baseline( |
382 | | - code_context=code_context, |
383 | | - original_helper_code=original_helper_code, |
384 | | - function_to_concolic_tests=function_to_concolic_tests, |
385 | | - generated_test_paths=generated_test_paths, |
386 | | - generated_perf_test_paths=generated_perf_test_paths, |
387 | | - instrumented_unittests_created_for_function=instrumented_unittests_created_for_function, |
388 | | - original_conftest_content=original_conftest_content, |
389 | | - ) |
390 | | - |
391 | | - if not is_successful(baseline_setup_result): |
392 | | - return {"functionName": params.functionName, "status": "error", "message": baseline_setup_result.failure()} |
393 | | - |
394 | | - ( |
395 | | - function_to_optimize_qualified_name, |
396 | | - function_to_all_tests, |
397 | | - original_code_baseline, |
398 | | - test_functions_to_remove, |
399 | | - file_path_to_helper_classes, |
400 | | - ) = baseline_setup_result.unwrap() |
401 | | - |
402 | | - best_optimization = function_optimizer.find_and_process_best_optimization( |
403 | | - optimizations_set=optimizations_set, |
404 | | - code_context=code_context, |
405 | | - original_code_baseline=original_code_baseline, |
406 | | - original_helper_code=original_helper_code, |
407 | | - file_path_to_helper_classes=file_path_to_helper_classes, |
408 | | - function_to_optimize_qualified_name=function_to_optimize_qualified_name, |
409 | | - function_to_all_tests=function_to_all_tests, |
410 | | - generated_tests=generated_tests, |
411 | | - test_functions_to_remove=test_functions_to_remove, |
412 | | - concolic_test_str=concolic_test_str, |
413 | | - ) |
414 | | - |
415 | | - if not best_optimization: |
416 | | - server.show_message_log( |
417 | | - f"No best optimizations found for function {function_to_optimize_qualified_name}", "Warning" |
418 | | - ) |
419 | | - return { |
420 | | - "functionName": params.functionName, |
421 | | - "status": "error", |
422 | | - "message": f"No best optimizations found for function {function_to_optimize_qualified_name}", |
423 | | - } |
424 | | - |
425 | | - # generate a patch for the optimization |
426 | | - relative_file_paths = [code_string.file_path for code_string in code_context.read_writable_code.code_strings] |
427 | | - |
428 | | - speedup = original_code_baseline.runtime / best_optimization.runtime |
429 | | - |
430 | | - patch_path = create_diff_patch_from_worktree( |
431 | | - server.optimizer.current_worktree, relative_file_paths, function_to_optimize_qualified_name |
432 | | - ) |
433 | | - |
434 | | - if not patch_path: |
435 | | - return { |
436 | | - "functionName": params.functionName, |
437 | | - "status": "error", |
438 | | - "message": "Failed to create a patch for optimization", |
439 | | - } |
440 | | - |
441 | | - server.show_message_log(f"Optimization completed for {params.functionName} with {speedup:.2f}x speedup", "Info") |
442 | | - |
443 | | - return { |
444 | | - "functionName": params.functionName, |
445 | | - "status": "success", |
446 | | - "message": "Optimization completed successfully", |
447 | | - "extra": f"Speedup: {speedup:.2f}x faster", |
448 | | - "patch_file": str(patch_path), |
449 | | - "task_id": params.task_id, |
450 | | - "explanation": best_optimization.explanation_v2, |
451 | | - } |
| 338 | + result = await loop.run_in_executor(None, sync_perform_optimization, server, params) |
| 339 | + except asyncio.CancelledError: |
| 340 | + return {"status": "canceled", "message": "Task was canceled"} |
| 341 | + else: |
| 342 | + return result |
452 | 343 | finally: |
453 | 344 | server.cleanup_the_optimizer() |
0 commit comments