Skip to content

Commit 2c84b05

Browse files
author
Codeflash Bot
committed
read and save configurations
1 parent 7aee1c1 commit 2c84b05

File tree

2 files changed

+96
-27
lines changed

2 files changed

+96
-27
lines changed

codeflash/cli_cmds/cmd_init.py

Lines changed: 30 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -58,8 +58,9 @@ class SetupInfo:
5858
benchmarks_root: Union[str, None]
5959
test_framework: str
6060
ignore_paths: list[str]
61-
formatter: str
61+
formatter: Union[str, list[str]]
6262
git_remote: str
63+
enable_telemetry: bool
6364

6465

6566
class DependencyManager(Enum):
@@ -93,7 +94,9 @@ def init_codeflash() -> None:
9394
if should_modify:
9495
setup_info: SetupInfo = collect_setup_info()
9596
git_remote = setup_info.git_remote
96-
configure_pyproject_toml(setup_info)
97+
configured = configure_pyproject_toml(setup_info)
98+
if not configured:
99+
apologize_and_exit()
97100

98101
install_github_app(git_remote)
99102

@@ -156,30 +159,30 @@ def ask_run_end_to_end_test(args: Namespace) -> None:
156159
run_end_to_end_test(args, bubble_sort_path, bubble_sort_test_path)
157160

158161

159-
def is_valid_pyproject_toml(pyproject_toml_path: Path) -> tuple[dict[str, Any] | None, str]: # noqa: PLR0911
162+
def is_valid_pyproject_toml(pyproject_toml_path: Path) -> tuple[bool, dict[str, Any] | None, str]: # noqa: PLR0911
160163
if not pyproject_toml_path.exists():
161-
return None, f"Configuration file not found: {pyproject_toml_path}"
164+
return False, None, f"Configuration file not found: {pyproject_toml_path}"
162165

163166
try:
164167
config, _ = parse_config_file(pyproject_toml_path)
165168
except Exception as e:
166-
return None, f"Failed to parse configuration: {e}"
169+
return False, None, f"Failed to parse configuration: {e}"
167170

168171
module_root = config.get("module_root")
169172
if not module_root:
170-
return None, "Missing required field: 'module_root'"
173+
return False, config, "Missing required field: 'module_root'"
171174

172175
if not Path(module_root).is_dir():
173-
return None, f"Invalid 'module_root': directory does not exist at {module_root}"
176+
return False, config, f"Invalid 'module_root': directory does not exist at {module_root}"
174177

175178
tests_root = config.get("tests_root")
176179
if not tests_root:
177-
return None, "Missing required field: 'tests_root'"
180+
return False, config, "Missing required field: 'tests_root'"
178181

179182
if not Path(tests_root).is_dir():
180-
return None, f"Invalid 'tests_root': directory does not exist at {tests_root}"
183+
return False, config, f"Invalid 'tests_root': directory does not exist at {tests_root}"
181184

182-
return config, ""
185+
return True, config, ""
183186

184187

185188
def should_modify_pyproject_toml() -> tuple[bool, dict[str, Any] | None]:
@@ -191,8 +194,9 @@ def should_modify_pyproject_toml() -> tuple[bool, dict[str, Any] | None]:
191194

192195
pyproject_toml_path = Path.cwd() / "pyproject.toml"
193196

194-
config, _message = is_valid_pyproject_toml(pyproject_toml_path)
195-
if config is None:
197+
valid, config, _message = is_valid_pyproject_toml(pyproject_toml_path)
198+
if not valid:
199+
# needs to be re-configured
196200
return True, None
197201

198202
return Confirm.ask(
@@ -532,6 +536,8 @@ def collect_setup_info() -> SetupInfo:
532536
except InvalidGitRepositoryError:
533537
git_remote = ""
534538

539+
enable_telemetry = ask_for_telemetry()
540+
535541
ignore_paths: list[str] = []
536542
return SetupInfo(
537543
module_root=str(module_root),
@@ -541,6 +547,7 @@ def collect_setup_info() -> SetupInfo:
541547
ignore_paths=ignore_paths,
542548
formatter=cast("str", formatter),
543549
git_remote=str(git_remote),
550+
enable_telemetry=enable_telemetry,
544551
)
545552

546553

@@ -966,8 +973,8 @@ def customize_codeflash_yaml_content(
966973

967974

968975
# Create or update the pyproject.toml file with the Codeflash dependency & configuration
969-
def configure_pyproject_toml(setup_info: SetupInfo) -> None:
970-
toml_path = Path.cwd() / "pyproject.toml"
976+
def configure_pyproject_toml(setup_info: SetupInfo, config_file: Optional[Path] = None) -> bool:
977+
toml_path = config_file or Path.cwd() / "pyproject.toml"
971978
try:
972979
with toml_path.open(encoding="utf8") as pyproject_file:
973980
pyproject_data = tomlkit.parse(pyproject_file.read())
@@ -976,23 +983,24 @@ def configure_pyproject_toml(setup_info: SetupInfo) -> None:
976983
f"I couldn't find a pyproject.toml in the current directory.{LF}"
977984
f"Please create a new empty pyproject.toml file here, OR if you use poetry then run `poetry init`, OR run `codeflash init` again from a directory with an existing pyproject.toml file."
978985
)
979-
apologize_and_exit()
980-
981-
enable_telemetry = ask_for_telemetry()
986+
return False
982987

983988
codeflash_section = tomlkit.table()
984989
codeflash_section.add(tomlkit.comment("All paths are relative to this pyproject.toml's directory."))
985990
codeflash_section["module-root"] = setup_info.module_root
986991
codeflash_section["tests-root"] = setup_info.tests_root
987992
codeflash_section["test-framework"] = setup_info.test_framework
988993
codeflash_section["ignore-paths"] = setup_info.ignore_paths
989-
if not enable_telemetry:
990-
codeflash_section["disable-telemetry"] = not enable_telemetry
994+
if not setup_info.enable_telemetry:
995+
codeflash_section["disable-telemetry"] = not setup_info.enable_telemetry
991996
if setup_info.git_remote not in ["", "origin"]:
992997
codeflash_section["git-remote"] = setup_info.git_remote
993998
formatter = setup_info.formatter
994999
formatter_cmds = []
995-
if formatter == "black":
1000+
1001+
if isinstance(formatter, list):
1002+
formatter_cmds = formatter
1003+
elif formatter == "black":
9961004
formatter_cmds.append("black $file")
9971005
elif formatter == "ruff":
9981006
formatter_cmds.extend(["ruff check --exit-zero --fix $file", "ruff format $file"])
@@ -1003,6 +1011,7 @@ def configure_pyproject_toml(setup_info: SetupInfo) -> None:
10031011
)
10041012
elif formatter == "don't use a formatter":
10051013
formatter_cmds.append("disabled")
1014+
10061015
check_formatter_installed(formatter_cmds, exit_on_failure=False)
10071016
codeflash_section["formatter-cmds"] = formatter_cmds
10081017
# Add the 'codeflash' section, ensuring 'tool' section exists
@@ -1014,6 +1023,7 @@ def configure_pyproject_toml(setup_info: SetupInfo) -> None:
10141023
pyproject_file.write(tomlkit.dumps(pyproject_data))
10151024
click.echo(f"✅ Added Codeflash configuration to {toml_path}")
10161025
click.echo()
1026+
return True
10171027

10181028

10191029
def install_github_app(git_remote: str) -> None:

codeflash/lsp/beta.py

Lines changed: 66 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from __future__ import annotations
22

33
import asyncio
4+
import contextlib
45
import os
56
from dataclasses import dataclass
67
from pathlib import Path
@@ -10,6 +11,15 @@
1011

1112
from codeflash.api.cfapi import get_codeflash_api_key, get_user_id
1213
from codeflash.cli_cmds.cli import process_pyproject_config
14+
from codeflash.cli_cmds.cmd_init import (
15+
CommonSections,
16+
SetupInfo,
17+
configure_pyproject_toml,
18+
get_suggestions,
19+
get_valid_subdirs,
20+
is_valid_pyproject_toml,
21+
)
22+
from codeflash.code_utils.config_parser import parse_config_file
1323
from codeflash.code_utils.git_utils import git_root_dir
1424
from codeflash.code_utils.shell_utils import save_api_key_to_rc
1525
from codeflash.discovery.functions_to_optimize import (
@@ -69,6 +79,12 @@ class OptimizableFunctionsInCommitParams:
6979
commit_hash: str
7080

7181

82+
@dataclass
83+
class WriteConfigParams:
84+
config_file: str
85+
config: any
86+
87+
7288
server = CodeflashLanguageServer("codeflash-language-server", "v1.0")
7389

7490

@@ -152,11 +168,51 @@ def _find_pyproject_toml(workspace_path: str) -> tuple[Path | None, bool]:
152168
return top_level_pyproject, False
153169

154170

171+
@server.feature("writeConfig")
172+
def write_config(_server: CodeflashLanguageServer, params: WriteConfigParams) -> dict[str, any]:
173+
cfg = params.config
174+
cfg_file = Path(params.config_file)
175+
176+
try:
177+
parsed_config, _ = parse_config_file(cfg_file)
178+
except Exception as e:
179+
return {"status": "error", "message": f"Failed to parse configuration: {e}"}
180+
_server.show_message_log(f"{parsed_config}", "Info")
181+
setup_info = SetupInfo(
182+
module_root=getattr(cfg, "module_root", ""),
183+
tests_root=getattr(cfg, "tests_root", ""),
184+
test_framework=getattr(cfg, "test_framework", "pytest"),
185+
# keep other stuff as it is
186+
benchmarks_root=None, # we don't support benchmarks in the LSP
187+
ignore_paths=parsed_config.get("ignore_paths", []),
188+
formatter=parsed_config.get("formatter_cmds", ["disabled"]),
189+
git_remote=parsed_config.get("git_remote", ""),
190+
enable_telemetry=parsed_config.get("disable_telemetry", True),
191+
)
192+
devnull_writer = open(os.devnull, "w") # noqa
193+
with contextlib.redirect_stdout(devnull_writer):
194+
configured = configure_pyproject_toml(setup_info, cfg_file)
195+
if configured:
196+
return {"status": "success"}
197+
return {"status": "error", "message": "Failed to configure pyproject.toml"}
198+
199+
200+
@server.feature("getConfigSuggestions")
201+
def get_config_suggestions(_server: CodeflashLanguageServer, _params: any) -> dict[str, any]:
202+
module_root_suggestions, default_module_root = get_suggestions(CommonSections.module_root)
203+
tests_root_suggestions, default_tests_root = get_suggestions(CommonSections.tests_root)
204+
test_framework_suggestions, default_test_framework = get_suggestions(CommonSections.test_framework)
205+
get_valid_subdirs.cache_clear()
206+
return {
207+
"module_root": {"choices": module_root_suggestions, "default": default_module_root},
208+
"tests_root": {"choices": tests_root_suggestions, "default": default_tests_root},
209+
"test_framework": {"choices": test_framework_suggestions, "default": default_test_framework},
210+
}
211+
212+
155213
# should be called the first thing to initialize and validate the project
156214
@server.feature("initProject")
157215
def init_project(server: CodeflashLanguageServer, params: ValidateProjectParams) -> dict[str, str]:
158-
from codeflash.cli_cmds.cmd_init import is_valid_pyproject_toml
159-
160216
# Always process args in the init project, the extension can call
161217
server.args_processed_before = False
162218

@@ -190,11 +246,14 @@ def init_project(server: CodeflashLanguageServer, params: ValidateProjectParams)
190246
"root": root,
191247
}
192248

193-
server.show_message_log("Validating project...", "Info")
194-
config, reason = is_valid_pyproject_toml(pyproject_toml_path)
195-
if config is None:
196-
server.show_message_log("pyproject.toml is not valid", "Error")
197-
return {"status": "error", "message": f"reason: {reason}", "pyprojectPath": pyproject_toml_path}
249+
valid, config, reason = is_valid_pyproject_toml(pyproject_toml_path)
250+
if not valid:
251+
return {
252+
"status": "error",
253+
"message": f"reason: {reason}",
254+
"pyprojectPath": pyproject_toml_path,
255+
"existingConfig": config,
256+
}
198257

199258
args = process_args(server)
200259

0 commit comments

Comments
 (0)