2424 get_functions_within_git_diff ,
2525)
2626from codeflash .either import is_successful
27- from codeflash .lsp .server import CodeflashLanguageServer , CodeflashLanguageServerProtocol
27+ from codeflash .lsp .server import CodeflashLanguageServer
2828
2929if TYPE_CHECKING :
3030 from argparse import Namespace
@@ -50,6 +50,13 @@ class ProvideApiKeyParams:
5050 api_key : str
5151
5252
53+ @dataclass
54+ class ValidateProjectParams :
55+ root_path_abs : str
56+ config_file : Optional [str ] = None
57+ skip_validation : bool = False
58+
59+
5360@dataclass
5461class OnPatchAppliedParams :
5562 patch_id : str
@@ -60,7 +67,8 @@ class OptimizableFunctionsInCommitParams:
6067 commit_hash : str
6168
6269
63- server = CodeflashLanguageServer ("codeflash-language-server" , "v1.0" , protocol_cls = CodeflashLanguageServerProtocol )
70+ # server = CodeflashLanguageServer("codeflash-language-server", "v1.0", protocol_cls=CodeflashLanguageServerProtocol)
71+ server = CodeflashLanguageServer ("codeflash-language-server" , "v1.0" )
6472
6573
6674@server .feature ("getOptimizableFunctionsInCurrentDiff" )
@@ -160,17 +168,60 @@ def initialize_function_optimization(
160168 return {"functionName" : params .functionName , "status" : "success" }
161169
162170
163- @server .feature ("validateProject" )
164- def validate_project (server : CodeflashLanguageServer , _params : FunctionOptimizationParams ) -> dict [str , str ]:
171+ def _find_pyproject_toml (workspace_path : str ) -> Path | None :
172+ workspace_path_obj = Path (workspace_path )
173+ max_depth = 2
174+ base_depth = len (workspace_path_obj .parts )
175+
176+ for root , dirs , files in os .walk (workspace_path_obj ):
177+ depth = len (Path (root ).parts ) - base_depth
178+ if depth > max_depth :
179+ # stop going deeper into this branch
180+ dirs .clear ()
181+ continue
182+
183+ if "pyproject.toml" in files :
184+ file_path = Path (root ) / "pyproject.toml"
185+ with file_path .open ("r" , encoding = "utf-8" , errors = "ignore" ) as f :
186+ for line in f :
187+ if line .strip () == "[tool.codeflash]" :
188+ return file_path .resolve ()
189+ return None
190+
191+
192+ # should be called the first thing to initialize and validate the project
193+ @server .feature ("initProject" )
194+ def init_project (server : CodeflashLanguageServer , params : ValidateProjectParams ) -> dict [str , str ]:
165195 from codeflash .cli_cmds .cmd_init import is_valid_pyproject_toml
166196
197+ pyproject_toml_path : Path | None = getattr (params , "config_file" , None )
198+
199+ if server .args is None :
200+ if pyproject_toml_path is not None :
201+ # if there is a config file provided use it
202+ server .prepare_optimizer_arguments (pyproject_toml_path )
203+ else :
204+ # otherwise look for it
205+ pyproject_toml_path = _find_pyproject_toml (params .root_path_abs )
206+ server .show_message_log (f"Found pyproject.toml at: { pyproject_toml_path } " , "Info" )
207+ if pyproject_toml_path :
208+ server .prepare_optimizer_arguments (pyproject_toml_path )
209+ else :
210+ return {
211+ "status" : "error" ,
212+ "message" : "No pyproject.toml found in workspace." ,
213+ } # TODO: enhancec this message to say there is not tool.codeflash in pyproject.toml or smth
214+
215+ if getattr (params , "skip_validation" , False ):
216+ return {"status" : "success" , "moduleRoot" : server .args .module_root , "pyprojectPath" : pyproject_toml_path }
217+
167218 server .show_message_log ("Validating project..." , "Info" )
168- config = is_valid_pyproject_toml (server . args . config_file )
219+ config = is_valid_pyproject_toml (pyproject_toml_path )
169220 if config is None :
170221 server .show_message_log ("pyproject.toml is not valid" , "Error" )
171222 return {
172223 "status" : "error" ,
173- "message" : "pyproject.toml is not valid" , # keep the error message the same, the extension is matching "pyproject.toml" in the error message to show the codeflash init instructions
224+ "message" : "pyproject.toml is not valid" , # keep the error message the same, the extension is matching "pyproject.toml" in the error message to show the codeflash init instructions,
174225 }
175226
176227 args = process_args (server )
@@ -183,7 +234,7 @@ def validate_project(server: CodeflashLanguageServer, _params: FunctionOptimizat
183234 except Exception :
184235 return {"status" : "error" , "message" : "Repository has no commits (unborn HEAD)" }
185236
186- return {"status" : "success" , "moduleRoot" : args .module_root }
237+ return {"status" : "success" , "moduleRoot" : args .module_root , "pyprojectPath" : pyproject_toml_path }
187238
188239
189240def _initialize_optimizer_if_api_key_is_valid (
@@ -328,7 +379,7 @@ def perform_function_optimization( # noqa: PLR0911
328379
329380 devnull_writer = open (os .devnull , "w" ) # noqa
330381 with contextlib .redirect_stdout (devnull_writer ):
331- function_to_tests , num_discovered_tests = server .optimizer .discover_tests (optimizable_funcs )
382+ function_to_tests , _num_discovered_tests = server .optimizer .discover_tests (optimizable_funcs )
332383 function_optimizer .function_to_tests = function_to_tests
333384
334385 test_setup_result = function_optimizer .generate_and_instrument_tests (
0 commit comments