Skip to content

Commit 4efbbd0

Browse files
author
Codeflash Bot
committed
vscode setup info
1 parent 71f8076 commit 4efbbd0

File tree

2 files changed

+66
-41
lines changed

2 files changed

+66
-41
lines changed

codeflash/cli_cmds/cmd_init.py

Lines changed: 59 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,14 @@ class SetupInfo:
6464
enable_telemetry: bool
6565

6666

67+
@dataclass(frozen=True)
68+
class VsCodeSetupInfo:
69+
module_root: str
70+
tests_root: str
71+
test_framework: str
72+
formatter: Union[str, list[str]]
73+
74+
6775
class DependencyManager(Enum):
6876
PIP = auto()
6977
POETRY = auto()
@@ -225,6 +233,10 @@ class CommonSections(Enum):
225233
module_root = "module_root"
226234
tests_root = "tests_root"
227235
test_framework = "test_framework"
236+
formatter_cmds = "formatter_cmds"
237+
238+
def get_toml_key(self) -> str:
239+
return self.value.replace("_", "-")
228240

229241

230242
@lru_cache(maxsize=1)
@@ -256,6 +268,8 @@ def get_suggestions(section: str) -> tuple(list[str], Optional[str]):
256268
if section == CommonSections.test_framework:
257269
auto_detected = detect_test_framework_from_config_files(Path.cwd())
258270
return ["pytest", "unittest"], auto_detected
271+
if section == CommonSections.formatter_cmds:
272+
return ["disabled", "ruff", "black"], "disabled"
259273
msg = f"Unknown section: {section}"
260274
raise ValueError(msg)
261275

@@ -973,8 +987,24 @@ def customize_codeflash_yaml_content(
973987
return optimize_yml_content.replace("{{ codeflash_command }}", codeflash_cmd)
974988

975989

990+
def get_formatter_cmds(formatter: str) -> list[str]:
991+
if formatter == "black":
992+
return ["black $file"]
993+
if formatter == "ruff":
994+
return ["ruff check --exit-zero --fix $file", "ruff format $file"]
995+
if formatter == "other":
996+
click.echo(
997+
"🔧 In pyproject.toml, please replace 'your-formatter' with the command you use to format your code."
998+
)
999+
return ["your-formatter $file"]
1000+
if formatter in {"don't use a formatter", "disabled"}:
1001+
return ["disabled"]
1002+
return [formatter]
1003+
1004+
9761005
# Create or update the pyproject.toml file with the Codeflash dependency & configuration
977-
def configure_pyproject_toml(setup_info: SetupInfo, config_file: Optional[Path] = None) -> bool:
1006+
def configure_pyproject_toml(setup_info: Union[VsCodeSetupInfo, SetupInfo], config_file: Optional[Path] = None) -> bool:
1007+
for_vscode = isinstance(setup_info, VsCodeSetupInfo)
9781008
toml_path = config_file or Path.cwd() / "pyproject.toml"
9791009
try:
9801010
with toml_path.open(encoding="utf8") as pyproject_file:
@@ -988,36 +1018,40 @@ def configure_pyproject_toml(setup_info: SetupInfo, config_file: Optional[Path]
9881018

9891019
codeflash_section = tomlkit.table()
9901020
codeflash_section.add(tomlkit.comment("All paths are relative to this pyproject.toml's directory."))
991-
codeflash_section["module-root"] = setup_info.module_root
992-
codeflash_section["tests-root"] = setup_info.tests_root
993-
codeflash_section["test-framework"] = setup_info.test_framework
994-
codeflash_section["ignore-paths"] = setup_info.ignore_paths
995-
if not setup_info.enable_telemetry:
996-
codeflash_section["disable-telemetry"] = not setup_info.enable_telemetry
997-
if setup_info.git_remote not in ["", "origin"]:
998-
codeflash_section["git-remote"] = setup_info.git_remote
1021+
1022+
if for_vscode:
1023+
for section in CommonSections:
1024+
if hasattr(setup_info, section.value):
1025+
codeflash_section[section.get_toml_key()] = getattr(setup_info, section.value)
1026+
else:
1027+
codeflash_section["module-root"] = setup_info.module_root
1028+
codeflash_section["tests-root"] = setup_info.tests_root
1029+
codeflash_section["test-framework"] = setup_info.test_framework
1030+
codeflash_section["ignore-paths"] = setup_info.ignore_paths
1031+
if not setup_info.enable_telemetry:
1032+
codeflash_section["disable-telemetry"] = not setup_info.enable_telemetry
1033+
if setup_info.git_remote not in ["", "origin"]:
1034+
codeflash_section["git-remote"] = setup_info.git_remote
1035+
9991036
formatter = setup_info.formatter
1000-
formatter_cmds = []
1001-
1002-
if isinstance(formatter, list):
1003-
formatter_cmds = formatter
1004-
elif formatter == "black":
1005-
formatter_cmds.append("black $file")
1006-
elif formatter == "ruff":
1007-
formatter_cmds.extend(["ruff check --exit-zero --fix $file", "ruff format $file"])
1008-
elif formatter == "other":
1009-
formatter_cmds.append("your-formatter $file")
1010-
click.echo(
1011-
"🔧 In pyproject.toml, please replace 'your-formatter' with the command you use to format your code."
1012-
)
1013-
elif formatter == "don't use a formatter":
1014-
formatter_cmds.append("disabled")
1037+
1038+
formatter_cmds = formatter if isinstance(formatter, list) else get_formatter_cmds(formatter)
10151039

10161040
check_formatter_installed(formatter_cmds, exit_on_failure=False)
10171041
codeflash_section["formatter-cmds"] = formatter_cmds
10181042
# Add the 'codeflash' section, ensuring 'tool' section exists
10191043
tool_section = pyproject_data.get("tool", tomlkit.table())
1020-
tool_section["codeflash"] = codeflash_section
1044+
1045+
if for_vscode:
1046+
# merge the existing codeflash section, instead of overwriting it
1047+
existing_codeflash = tool_section.get("codeflash", tomlkit.table())
1048+
1049+
for key, value in codeflash_section.items():
1050+
existing_codeflash[key] = value
1051+
tool_section["codeflash"] = existing_codeflash
1052+
else:
1053+
tool_section["codeflash"] = codeflash_section
1054+
10211055
pyproject_data["tool"] = tool_section
10221056

10231057
with toml_path.open("w", encoding="utf8") as pyproject_file:

codeflash/lsp/beta.py

Lines changed: 7 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -13,14 +13,14 @@
1313
from codeflash.cli_cmds.cli import process_pyproject_config
1414
from codeflash.cli_cmds.cmd_init import (
1515
CommonSections,
16-
SetupInfo,
16+
VsCodeSetupInfo,
1717
configure_pyproject_toml,
1818
create_empty_pyproject_toml,
19+
get_formatter_cmds,
1920
get_suggestions,
2021
get_valid_subdirs,
2122
is_valid_pyproject_toml,
2223
)
23-
from codeflash.code_utils.config_parser import parse_config_file
2424
from codeflash.code_utils.git_utils import git_root_dir
2525
from codeflash.code_utils.shell_utils import save_api_key_to_rc
2626
from codeflash.discovery.functions_to_optimize import (
@@ -174,28 +174,17 @@ def write_config(_server: CodeflashLanguageServer, params: WriteConfigParams) ->
174174
cfg = params.config
175175
cfg_file = Path(params.config_file) if params.config_file else None
176176

177-
parsed_config = {}
178-
179177
if cfg_file and not cfg_file.exists():
180178
# the client provided a config path but it doesn't exist
181179
create_empty_pyproject_toml(cfg_file)
182-
elif cfg_file and cfg_file.exists():
183-
try:
184-
parsed_config, _ = parse_config_file(cfg_file)
185-
except Exception as e:
186-
return {"status": "error", "message": f"Failed to parse configuration: {e}"}
187180

188-
setup_info = SetupInfo(
181+
setup_info = VsCodeSetupInfo(
189182
module_root=getattr(cfg, "module_root", ""),
190183
tests_root=getattr(cfg, "tests_root", ""),
191184
test_framework=getattr(cfg, "test_framework", "pytest"),
192-
# keep other stuff as it is
193-
benchmarks_root=None, # we don't support benchmarks in the LSP
194-
ignore_paths=parsed_config.get("ignore_paths", []),
195-
formatter=parsed_config.get("formatter_cmds", ["disabled"]),
196-
git_remote=parsed_config.get("git_remote", ""),
197-
enable_telemetry=parsed_config.get("disable_telemetry", True),
185+
formatter=get_formatter_cmds(getattr(cfg, "formatter_cmds", "disabled")),
198186
)
187+
199188
devnull_writer = open(os.devnull, "w") # noqa
200189
with contextlib.redirect_stdout(devnull_writer):
201190
configured = configure_pyproject_toml(setup_info, cfg_file)
@@ -209,11 +198,13 @@ def get_config_suggestions(_server: CodeflashLanguageServer, _params: any) -> di
209198
module_root_suggestions, default_module_root = get_suggestions(CommonSections.module_root)
210199
tests_root_suggestions, default_tests_root = get_suggestions(CommonSections.tests_root)
211200
test_framework_suggestions, default_test_framework = get_suggestions(CommonSections.test_framework)
201+
formatter_suggestions, default_formatter = get_suggestions(CommonSections.formatter_cmds)
212202
get_valid_subdirs.cache_clear()
213203
return {
214204
"module_root": {"choices": module_root_suggestions, "default": default_module_root},
215205
"tests_root": {"choices": tests_root_suggestions, "default": default_tests_root},
216206
"test_framework": {"choices": test_framework_suggestions, "default": default_test_framework},
207+
"formatter_cmds": {"choices": formatter_suggestions, "default": default_formatter},
217208
}
218209

219210

0 commit comments

Comments
 (0)