@@ -64,6 +64,14 @@ class SetupInfo:
6464 enable_telemetry : bool
6565
6666
67+ @dataclass (frozen = True )
68+ class VsCodeSetupInfo :
69+ module_root : str
70+ tests_root : str
71+ test_framework : str
72+ formatter : Union [str , list [str ]]
73+
74+
6775class DependencyManager (Enum ):
6876 PIP = auto ()
6977 POETRY = auto ()
@@ -225,6 +233,10 @@ class CommonSections(Enum):
225233 module_root = "module_root"
226234 tests_root = "tests_root"
227235 test_framework = "test_framework"
236+ formatter_cmds = "formatter_cmds"
237+
238+ def get_toml_key (self ) -> str :
239+ return self .value .replace ("_" , "-" )
228240
229241
230242@lru_cache (maxsize = 1 )
@@ -256,6 +268,8 @@ def get_suggestions(section: str) -> tuple(list[str], Optional[str]):
256268 if section == CommonSections .test_framework :
257269 auto_detected = detect_test_framework_from_config_files (Path .cwd ())
258270 return ["pytest" , "unittest" ], auto_detected
271+ if section == CommonSections .formatter_cmds :
272+ return ["disabled" , "ruff" , "black" ], "disabled"
259273 msg = f"Unknown section: { section } "
260274 raise ValueError (msg )
261275
@@ -973,8 +987,24 @@ def customize_codeflash_yaml_content(
973987 return optimize_yml_content .replace ("{{ codeflash_command }}" , codeflash_cmd )
974988
975989
990+ def get_formatter_cmds (formatter : str ) -> list [str ]:
991+ if formatter == "black" :
992+ return ["black $file" ]
993+ if formatter == "ruff" :
994+ return ["ruff check --exit-zero --fix $file" , "ruff format $file" ]
995+ if formatter == "other" :
996+ click .echo (
997+ "🔧 In pyproject.toml, please replace 'your-formatter' with the command you use to format your code."
998+ )
999+ return ["your-formatter $file" ]
1000+ if formatter in {"don't use a formatter" , "disabled" }:
1001+ return ["disabled" ]
1002+ return [formatter ]
1003+
1004+
9761005# Create or update the pyproject.toml file with the Codeflash dependency & configuration
977- def configure_pyproject_toml (setup_info : SetupInfo , config_file : Optional [Path ] = None ) -> bool :
1006+ def configure_pyproject_toml (setup_info : Union [VsCodeSetupInfo , SetupInfo ], config_file : Optional [Path ] = None ) -> bool :
1007+ for_vscode = isinstance (setup_info , VsCodeSetupInfo )
9781008 toml_path = config_file or Path .cwd () / "pyproject.toml"
9791009 try :
9801010 with toml_path .open (encoding = "utf8" ) as pyproject_file :
@@ -988,36 +1018,40 @@ def configure_pyproject_toml(setup_info: SetupInfo, config_file: Optional[Path]
9881018
9891019 codeflash_section = tomlkit .table ()
9901020 codeflash_section .add (tomlkit .comment ("All paths are relative to this pyproject.toml's directory." ))
991- codeflash_section ["module-root" ] = setup_info .module_root
992- codeflash_section ["tests-root" ] = setup_info .tests_root
993- codeflash_section ["test-framework" ] = setup_info .test_framework
994- codeflash_section ["ignore-paths" ] = setup_info .ignore_paths
995- if not setup_info .enable_telemetry :
996- codeflash_section ["disable-telemetry" ] = not setup_info .enable_telemetry
997- if setup_info .git_remote not in ["" , "origin" ]:
998- codeflash_section ["git-remote" ] = setup_info .git_remote
1021+
1022+ if for_vscode :
1023+ for section in CommonSections :
1024+ if hasattr (setup_info , section .value ):
1025+ codeflash_section [section .get_toml_key ()] = getattr (setup_info , section .value )
1026+ else :
1027+ codeflash_section ["module-root" ] = setup_info .module_root
1028+ codeflash_section ["tests-root" ] = setup_info .tests_root
1029+ codeflash_section ["test-framework" ] = setup_info .test_framework
1030+ codeflash_section ["ignore-paths" ] = setup_info .ignore_paths
1031+ if not setup_info .enable_telemetry :
1032+ codeflash_section ["disable-telemetry" ] = not setup_info .enable_telemetry
1033+ if setup_info .git_remote not in ["" , "origin" ]:
1034+ codeflash_section ["git-remote" ] = setup_info .git_remote
1035+
9991036 formatter = setup_info .formatter
1000- formatter_cmds = []
1001-
1002- if isinstance (formatter , list ):
1003- formatter_cmds = formatter
1004- elif formatter == "black" :
1005- formatter_cmds .append ("black $file" )
1006- elif formatter == "ruff" :
1007- formatter_cmds .extend (["ruff check --exit-zero --fix $file" , "ruff format $file" ])
1008- elif formatter == "other" :
1009- formatter_cmds .append ("your-formatter $file" )
1010- click .echo (
1011- "🔧 In pyproject.toml, please replace 'your-formatter' with the command you use to format your code."
1012- )
1013- elif formatter == "don't use a formatter" :
1014- formatter_cmds .append ("disabled" )
1037+
1038+ formatter_cmds = formatter if isinstance (formatter , list ) else get_formatter_cmds (formatter )
10151039
10161040 check_formatter_installed (formatter_cmds , exit_on_failure = False )
10171041 codeflash_section ["formatter-cmds" ] = formatter_cmds
10181042 # Add the 'codeflash' section, ensuring 'tool' section exists
10191043 tool_section = pyproject_data .get ("tool" , tomlkit .table ())
1020- tool_section ["codeflash" ] = codeflash_section
1044+
1045+ if for_vscode :
1046+ # merge the existing codeflash section, instead of overwriting it
1047+ existing_codeflash = tool_section .get ("codeflash" , tomlkit .table ())
1048+
1049+ for key , value in codeflash_section .items ():
1050+ existing_codeflash [key ] = value
1051+ tool_section ["codeflash" ] = existing_codeflash
1052+ else :
1053+ tool_section ["codeflash" ] = codeflash_section
1054+
10211055 pyproject_data ["tool" ] = tool_section
10221056
10231057 with toml_path .open ("w" , encoding = "utf8" ) as pyproject_file :
0 commit comments