@@ -131,10 +131,11 @@ def get_optimizable_functions(
131131 return path_to_qualified_names
132132
133133
134- def _find_pyproject_toml (workspace_path : str ) -> Path | None :
134+ def _find_pyproject_toml (workspace_path : str ) -> tuple [ Path | None , bool ] :
135135 workspace_path_obj = Path (workspace_path )
136136 max_depth = 2
137137 base_depth = len (workspace_path_obj .parts )
138+ top_level_pyproject = None
138139
139140 for root , dirs , files in os .walk (workspace_path_obj ):
140141 depth = len (Path (root ).parts ) - base_depth
@@ -145,32 +146,39 @@ def _find_pyproject_toml(workspace_path: str) -> Path | None:
145146
146147 if "pyproject.toml" in files :
147148 file_path = Path (root ) / "pyproject.toml"
149+ if depth == 0 :
150+ top_level_pyproject = file_path
148151 with file_path .open ("r" , encoding = "utf-8" , errors = "ignore" ) as f :
149152 for line in f :
150153 if line .strip () == "[tool.codeflash]" :
151- return file_path .resolve ()
152- return None
154+ return file_path .resolve (), True
155+ return top_level_pyproject , False
153156
154157
155158# should be called the first thing to initialize and validate the project
156159@server .feature ("initProject" )
157- def init_project (server : CodeflashLanguageServer , params : ValidateProjectParams ) -> dict [str , str ]:
160+ def init_project (server : CodeflashLanguageServer , params : ValidateProjectParams ) -> dict [str , str ]: # noqa: PLR0911
158161 from codeflash .cli_cmds .cmd_init import is_valid_pyproject_toml
159162
160- pyproject_toml_path : Path | None = getattr (params , "config_file" , None )
163+ pyproject_toml_path : Path | None = getattr (params , "config_file" , None ) or getattr ( server . args , "config_file" , None )
161164
162- if server .args is None :
163- if pyproject_toml_path is not None :
164- # if there is a config file provided use it
165+ if pyproject_toml_path is not None :
166+ # if there is a config file provided use it
167+ server .prepare_optimizer_arguments (pyproject_toml_path )
168+ else :
169+ # otherwise look for it
170+ pyproject_toml_path , has_codeflash_config = _find_pyproject_toml (params .root_path_abs )
171+ if pyproject_toml_path and has_codeflash_config :
172+ server .show_message_log (f"Found pyproject.toml at: { pyproject_toml_path } " , "Info" )
165173 server .prepare_optimizer_arguments (pyproject_toml_path )
174+ elif pyproject_toml_path and not has_codeflash_config :
175+ return {
176+ "status" : "error" ,
177+ "message" : "pyproject.toml found in workspace, but no codeflash config." ,
178+ "pyprojectPath" : pyproject_toml_path ,
179+ }
166180 else :
167- # otherwise look for it
168- pyproject_toml_path = _find_pyproject_toml (params .root_path_abs )
169- server .show_message_log (f"Found pyproject.toml at: { pyproject_toml_path } " , "Info" )
170- if pyproject_toml_path :
171- server .prepare_optimizer_arguments (pyproject_toml_path )
172- else :
173- return {"status" : "error" , "message" : "No pyproject.toml found in workspace." }
181+ return {"status" : "error" , "message" : "No pyproject.toml found in workspace." }
174182
175183 # since we are using worktrees, optimization diffs are generated with respect to the root of the repo.
176184 root = str (git_root_dir ())
@@ -187,10 +195,7 @@ def init_project(server: CodeflashLanguageServer, params: ValidateProjectParams)
187195 config = is_valid_pyproject_toml (pyproject_toml_path )
188196 if config is None :
189197 server .show_message_log ("pyproject.toml is not valid" , "Error" )
190- return {
191- "status" : "error" ,
192- "message" : "pyproject.toml is not valid" , # keep the error message the same, the extension is matching "pyproject.toml" in the error message to show the codeflash init instructions,
193- }
198+ return {"status" : "error" , "message" : "not valid" , "pyprojectPath" : pyproject_toml_path }
194199
195200 args = process_args (server )
196201 repo = git .Repo (args .module_root , search_parent_directories = True )
0 commit comments