Skip to content

Commit e84e57a

Browse files
[LSP] search for the nearest pyproject.toml file in the init feature rather than the initialization feature (#743)
* move searching for pyproject file from the server initialization to the validation feature * cleanup * formatting
1 parent d3e6427 commit e84e57a

File tree

4 files changed

+64
-38
lines changed

4 files changed

+64
-38
lines changed

codeflash/code_utils/env_utils.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@ def get_codeflash_api_key() -> str:
4242
if env_api_key and not shell_api_key:
4343
try:
4444
from codeflash.either import is_successful
45+
4546
result = save_api_key_to_rc(env_api_key)
4647
if is_successful(result):
4748
logger.debug(f"Automatically saved API key from environment to shell config: {result.unwrap()}")

codeflash/lsp/beta.py

Lines changed: 59 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424
get_functions_within_git_diff,
2525
)
2626
from codeflash.either import is_successful
27-
from codeflash.lsp.server import CodeflashLanguageServer, CodeflashLanguageServerProtocol
27+
from codeflash.lsp.server import CodeflashLanguageServer
2828

2929
if TYPE_CHECKING:
3030
from argparse import Namespace
@@ -50,6 +50,13 @@ class ProvideApiKeyParams:
5050
api_key: str
5151

5252

53+
@dataclass
54+
class ValidateProjectParams:
55+
root_path_abs: str
56+
config_file: Optional[str] = None
57+
skip_validation: bool = False
58+
59+
5360
@dataclass
5461
class OnPatchAppliedParams:
5562
patch_id: str
@@ -60,7 +67,8 @@ class OptimizableFunctionsInCommitParams:
6067
commit_hash: str
6168

6269

63-
server = CodeflashLanguageServer("codeflash-language-server", "v1.0", protocol_cls=CodeflashLanguageServerProtocol)
70+
# server = CodeflashLanguageServer("codeflash-language-server", "v1.0", protocol_cls=CodeflashLanguageServerProtocol)
71+
server = CodeflashLanguageServer("codeflash-language-server", "v1.0")
6472

6573

6674
@server.feature("getOptimizableFunctionsInCurrentDiff")
@@ -160,17 +168,60 @@ def initialize_function_optimization(
160168
return {"functionName": params.functionName, "status": "success"}
161169

162170

163-
@server.feature("validateProject")
164-
def validate_project(server: CodeflashLanguageServer, _params: FunctionOptimizationParams) -> dict[str, str]:
171+
def _find_pyproject_toml(workspace_path: str) -> Path | None:
172+
workspace_path_obj = Path(workspace_path)
173+
max_depth = 2
174+
base_depth = len(workspace_path_obj.parts)
175+
176+
for root, dirs, files in os.walk(workspace_path_obj):
177+
depth = len(Path(root).parts) - base_depth
178+
if depth > max_depth:
179+
# stop going deeper into this branch
180+
dirs.clear()
181+
continue
182+
183+
if "pyproject.toml" in files:
184+
file_path = Path(root) / "pyproject.toml"
185+
with file_path.open("r", encoding="utf-8", errors="ignore") as f:
186+
for line in f:
187+
if line.strip() == "[tool.codeflash]":
188+
return file_path.resolve()
189+
return None
190+
191+
192+
# should be called the first thing to initialize and validate the project
193+
@server.feature("initProject")
194+
def init_project(server: CodeflashLanguageServer, params: ValidateProjectParams) -> dict[str, str]:
165195
from codeflash.cli_cmds.cmd_init import is_valid_pyproject_toml
166196

197+
pyproject_toml_path: Path | None = getattr(params, "config_file", None)
198+
199+
if server.args is None:
200+
if pyproject_toml_path is not None:
201+
# if there is a config file provided use it
202+
server.prepare_optimizer_arguments(pyproject_toml_path)
203+
else:
204+
# otherwise look for it
205+
pyproject_toml_path = _find_pyproject_toml(params.root_path_abs)
206+
server.show_message_log(f"Found pyproject.toml at: {pyproject_toml_path}", "Info")
207+
if pyproject_toml_path:
208+
server.prepare_optimizer_arguments(pyproject_toml_path)
209+
else:
210+
return {
211+
"status": "error",
212+
"message": "No pyproject.toml found in workspace.",
213+
} # TODO: enhancec this message to say there is not tool.codeflash in pyproject.toml or smth
214+
215+
if getattr(params, "skip_validation", False):
216+
return {"status": "success", "moduleRoot": server.args.module_root, "pyprojectPath": pyproject_toml_path}
217+
167218
server.show_message_log("Validating project...", "Info")
168-
config = is_valid_pyproject_toml(server.args.config_file)
219+
config = is_valid_pyproject_toml(pyproject_toml_path)
169220
if config is None:
170221
server.show_message_log("pyproject.toml is not valid", "Error")
171222
return {
172223
"status": "error",
173-
"message": "pyproject.toml is not valid", # keep the error message the same, the extension is matching "pyproject.toml" in the error message to show the codeflash init instructions
224+
"message": "pyproject.toml is not valid", # keep the error message the same, the extension is matching "pyproject.toml" in the error message to show the codeflash init instructions,
174225
}
175226

176227
args = process_args(server)
@@ -183,7 +234,7 @@ def validate_project(server: CodeflashLanguageServer, _params: FunctionOptimizat
183234
except Exception:
184235
return {"status": "error", "message": "Repository has no commits (unborn HEAD)"}
185236

186-
return {"status": "success", "moduleRoot": args.module_root}
237+
return {"status": "success", "moduleRoot": args.module_root, "pyprojectPath": pyproject_toml_path}
187238

188239

189240
def _initialize_optimizer_if_api_key_is_valid(
@@ -328,7 +379,7 @@ def perform_function_optimization( # noqa: PLR0911
328379

329380
devnull_writer = open(os.devnull, "w") # noqa
330381
with contextlib.redirect_stdout(devnull_writer):
331-
function_to_tests, num_discovered_tests = server.optimizer.discover_tests(optimizable_funcs)
382+
function_to_tests, _num_discovered_tests = server.optimizer.discover_tests(optimizable_funcs)
332383
function_optimizer.function_to_tests = function_to_tests
333384

334385
test_setup_result = function_optimizer.generate_and_instrument_tests(

codeflash/lsp/lsp_logger.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -124,7 +124,7 @@ def setup_logging() -> logging.Logger:
124124
logger = logging.getLogger()
125125
logger.handlers.clear()
126126

127-
# Set up stderr handler for VS Code output channel with [LSP-Server] prefix
127+
# Set up stderr handler for VS Code output channel
128128
handler = logging.StreamHandler(sys.stderr)
129129
handler.setLevel(logging.DEBUG)
130130

codeflash/lsp/server.py

Lines changed: 3 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -1,46 +1,20 @@
11
from __future__ import annotations
22

3-
from pathlib import Path
43
from typing import TYPE_CHECKING, Any
54

6-
from lsprotocol.types import INITIALIZE, LogMessageParams, MessageType
7-
from pygls import uris
8-
from pygls.protocol import LanguageServerProtocol, lsp_method
5+
from lsprotocol.types import LogMessageParams, MessageType
6+
from pygls.protocol import LanguageServerProtocol
97
from pygls.server import LanguageServer
108

119
if TYPE_CHECKING:
12-
from lsprotocol.types import InitializeParams, InitializeResult
10+
from pathlib import Path
1311

1412
from codeflash.optimization.optimizer import Optimizer
1513

1614

1715
class CodeflashLanguageServerProtocol(LanguageServerProtocol):
1816
_server: CodeflashLanguageServer
1917

20-
@lsp_method(INITIALIZE)
21-
def lsp_initialize(self, params: InitializeParams) -> InitializeResult:
22-
server = self._server
23-
initialize_result: InitializeResult = super().lsp_initialize(params)
24-
25-
workspace_uri = params.root_uri
26-
if workspace_uri:
27-
workspace_path = uris.to_fs_path(workspace_uri)
28-
pyproject_toml_path = self._find_pyproject_toml(workspace_path)
29-
if pyproject_toml_path:
30-
server.prepare_optimizer_arguments(pyproject_toml_path)
31-
else:
32-
server.show_message("No pyproject.toml found in workspace.")
33-
else:
34-
server.show_message("No workspace URI provided.")
35-
36-
return initialize_result
37-
38-
def _find_pyproject_toml(self, workspace_path: str) -> Path | None:
39-
workspace_path_obj = Path(workspace_path)
40-
for file_path in workspace_path_obj.rglob("pyproject.toml"):
41-
return file_path.resolve()
42-
return None
43-
4418

4519
class CodeflashLanguageServer(LanguageServer):
4620
def __init__(self, *args: Any, **kwargs: Any) -> None: # noqa: ANN401

0 commit comments

Comments
 (0)