Skip to content

Commit 952bb6d

Browse files
Merge pull request #799 from codeflash-ai/lsp/return-selected-pyproject-toml
[LSP] Return selected pyproject.toml path
2 parents 22cc41b + 003fd6b commit 952bb6d

File tree

1 file changed

+24
-19
lines changed

1 file changed

+24
-19
lines changed

codeflash/lsp/beta.py

Lines changed: 24 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -131,10 +131,11 @@ def get_optimizable_functions(
131131
return path_to_qualified_names
132132

133133

134-
def _find_pyproject_toml(workspace_path: str) -> Path | None:
134+
def _find_pyproject_toml(workspace_path: str) -> tuple[Path | None, bool]:
135135
workspace_path_obj = Path(workspace_path)
136136
max_depth = 2
137137
base_depth = len(workspace_path_obj.parts)
138+
top_level_pyproject = None
138139

139140
for root, dirs, files in os.walk(workspace_path_obj):
140141
depth = len(Path(root).parts) - base_depth
@@ -145,32 +146,39 @@ def _find_pyproject_toml(workspace_path: str) -> Path | None:
145146

146147
if "pyproject.toml" in files:
147148
file_path = Path(root) / "pyproject.toml"
149+
if depth == 0:
150+
top_level_pyproject = file_path
148151
with file_path.open("r", encoding="utf-8", errors="ignore") as f:
149152
for line in f:
150153
if line.strip() == "[tool.codeflash]":
151-
return file_path.resolve()
152-
return None
154+
return file_path.resolve(), True
155+
return top_level_pyproject, False
153156

154157

155158
# should be called the first thing to initialize and validate the project
156159
@server.feature("initProject")
157-
def init_project(server: CodeflashLanguageServer, params: ValidateProjectParams) -> dict[str, str]:
160+
def init_project(server: CodeflashLanguageServer, params: ValidateProjectParams) -> dict[str, str]: # noqa: PLR0911
158161
from codeflash.cli_cmds.cmd_init import is_valid_pyproject_toml
159162

160-
pyproject_toml_path: Path | None = getattr(params, "config_file", None)
163+
pyproject_toml_path: Path | None = getattr(params, "config_file", None) or getattr(server.args, "config_file", None)
161164

162-
if server.args is None:
163-
if pyproject_toml_path is not None:
164-
# if there is a config file provided use it
165+
if pyproject_toml_path is not None:
166+
# if there is a config file provided use it
167+
server.prepare_optimizer_arguments(pyproject_toml_path)
168+
else:
169+
# otherwise look for it
170+
pyproject_toml_path, has_codeflash_config = _find_pyproject_toml(params.root_path_abs)
171+
if pyproject_toml_path and has_codeflash_config:
172+
server.show_message_log(f"Found pyproject.toml at: {pyproject_toml_path}", "Info")
165173
server.prepare_optimizer_arguments(pyproject_toml_path)
174+
elif pyproject_toml_path and not has_codeflash_config:
175+
return {
176+
"status": "error",
177+
"message": "pyproject.toml found in workspace, but no codeflash config.",
178+
"pyprojectPath": pyproject_toml_path,
179+
}
166180
else:
167-
# otherwise look for it
168-
pyproject_toml_path = _find_pyproject_toml(params.root_path_abs)
169-
server.show_message_log(f"Found pyproject.toml at: {pyproject_toml_path}", "Info")
170-
if pyproject_toml_path:
171-
server.prepare_optimizer_arguments(pyproject_toml_path)
172-
else:
173-
return {"status": "error", "message": "No pyproject.toml found in workspace."}
181+
return {"status": "error", "message": "No pyproject.toml found in workspace."}
174182

175183
# since we are using worktrees, optimization diffs are generated with respect to the root of the repo.
176184
root = str(git_root_dir())
@@ -187,10 +195,7 @@ def init_project(server: CodeflashLanguageServer, params: ValidateProjectParams)
187195
config = is_valid_pyproject_toml(pyproject_toml_path)
188196
if config is None:
189197
server.show_message_log("pyproject.toml is not valid", "Error")
190-
return {
191-
"status": "error",
192-
"message": "pyproject.toml is not valid", # keep the error message the same, the extension is matching "pyproject.toml" in the error message to show the codeflash init instructions,
193-
}
198+
return {"status": "error", "message": "not valid", "pyprojectPath": pyproject_toml_path}
194199

195200
args = process_args(server)
196201
repo = git.Repo(args.module_root, search_parent_directories=True)

0 commit comments

Comments
 (0)