44import os
55from dataclasses import dataclass
66from pathlib import Path
7- from typing import TYPE_CHECKING
7+ from typing import TYPE_CHECKING , Optional
88
99import git
1010from pygls import uris
1111
1212from codeflash .api .cfapi import get_codeflash_api_key , get_user_id
1313from codeflash .cli_cmds .cli import process_pyproject_config
14+ from codeflash .cli_cmds .console import code_print
1415from codeflash .code_utils .git_worktree_utils import (
1516 create_diff_patch_from_worktree ,
1617 get_patches_metadata ,
@@ -103,6 +104,8 @@ def get_optimizable_functions(
103104) -> dict [str , list [str ]]:
104105 file_path = Path (uris .to_fs_path (params .textDocument .uri ))
105106 server .show_message_log (f"Getting optimizable functions for: { file_path } " , "Info" )
107+ if not server .optimizer :
108+ return {"status" : "error" , "message" : "optimizer not initialized" }
106109
107110 server .optimizer .args .file = file_path
108111 server .optimizer .args .function = None # Always get ALL functions, not just one
@@ -157,20 +160,6 @@ def initialize_function_optimization(
157160 return {"functionName" : params .functionName , "status" : "success" }
158161
159162
160- @server .feature ("discoverFunctionTests" )
161- def discover_function_tests (server : CodeflashLanguageServer , params : FunctionOptimizationParams ) -> dict [str , str ]:
162- fto = server .optimizer .current_function_being_optimized
163- optimizable_funcs = {fto .file_path : [fto ]}
164-
165- devnull_writer = open (os .devnull , "w" ) # noqa
166- with contextlib .redirect_stdout (devnull_writer ):
167- function_to_tests , num_discovered_tests = server .optimizer .discover_tests (optimizable_funcs )
168-
169- server .optimizer .discovered_tests = function_to_tests
170-
171- return {"functionName" : params .functionName , "status" : "success" , "discovered_tests" : num_discovered_tests }
172-
173-
174163@server .feature ("validateProject" )
175164def validate_project (server : CodeflashLanguageServer , _params : FunctionOptimizationParams ) -> dict [str , str ]:
176165 from codeflash .cli_cmds .cmd_init import is_valid_pyproject_toml
@@ -194,11 +183,13 @@ def validate_project(server: CodeflashLanguageServer, _params: FunctionOptimizat
194183 except Exception :
195184 return {"status" : "error" , "message" : "Repository has no commits (unborn HEAD)" }
196185
197- return {"status" : "success" }
186+ return {"status" : "success" , "moduleRoot" : args . module_root }
198187
199188
200- def _initialize_optimizer_if_api_key_is_valid (server : CodeflashLanguageServer ) -> dict [str , str ]:
201- user_id = get_user_id ()
189+ def _initialize_optimizer_if_api_key_is_valid (
190+ server : CodeflashLanguageServer , api_key : Optional [str ] = None
191+ ) -> dict [str , str ]:
192+ user_id = get_user_id (api_key = api_key )
202193 if user_id is None :
203194 return {"status" : "error" , "message" : "api key not found or invalid" }
204195
@@ -237,19 +228,19 @@ def provide_api_key(server: CodeflashLanguageServer, params: ProvideApiKeyParams
237228 if not api_key .startswith ("cf-" ):
238229 return {"status" : "error" , "message" : "Api key is not valid" }
239230
240- result = save_api_key_to_rc (api_key )
241- if not is_successful (result ):
242- return {"status" : "error" , "message" : result .failure ()}
243-
244231 # clear cache to ensure the new api key is used
245232 get_codeflash_api_key .cache_clear ()
246233 get_user_id .cache_clear ()
247234
248- init_result = _initialize_optimizer_if_api_key_is_valid (server )
235+ init_result = _initialize_optimizer_if_api_key_is_valid (server , api_key )
249236 if init_result ["status" ] == "error" :
250237 return {"status" : "error" , "message" : "Api key is not valid" }
251238
252- return {"status" : "success" , "message" : "Api key saved successfully" , "user_id" : init_result ["user_id" ]}
239+ user_id = init_result ["user_id" ]
240+ result = save_api_key_to_rc (api_key )
241+ if not is_successful (result ):
242+ return {"status" : "error" , "message" : result .failure ()}
243+ return {"status" : "success" , "message" : "Api key saved successfully" , "user_id" : user_id } # noqa: TRY300
253244 except Exception :
254245 return {"status" : "error" , "message" : "something went wrong while saving the api key" }
255246
@@ -300,6 +291,12 @@ def perform_function_optimization( # noqa: PLR0911
300291 }
301292
302293 module_prep_result = server .optimizer .prepare_module_for_optimization (current_function .file_path )
294+ if not module_prep_result :
295+ return {
296+ "functionName" : params .functionName ,
297+ "status" : "error" ,
298+ "message" : "Failed to prepare module for optimization" ,
299+ }
303300
304301 validated_original_code , original_module_ast = module_prep_result
305302
@@ -308,7 +305,7 @@ def perform_function_optimization( # noqa: PLR0911
308305 function_to_optimize_source_code = validated_original_code [current_function .file_path ].source_code ,
309306 original_module_ast = original_module_ast ,
310307 original_module_path = current_function .file_path ,
311- function_to_tests = server . optimizer . discovered_tests or {},
308+ function_to_tests = {},
312309 )
313310
314311 server .optimizer .current_function_optimizer = function_optimizer
@@ -321,6 +318,19 @@ def perform_function_optimization( # noqa: PLR0911
321318
322319 should_run_experiment , code_context , original_helper_code = initialization_result .unwrap ()
323320
321+ code_print (
322+ code_context .read_writable_code .flat ,
323+ file_name = current_function .file_path ,
324+ function_name = current_function .function_name ,
325+ )
326+
327+ optimizable_funcs = {current_function .file_path : [current_function ]}
328+
329+ devnull_writer = open (os .devnull , "w" ) # noqa
330+ with contextlib .redirect_stdout (devnull_writer ):
331+ function_to_tests , num_discovered_tests = server .optimizer .discover_tests (optimizable_funcs )
332+ function_optimizer .function_to_tests = function_to_tests
333+
324334 test_setup_result = function_optimizer .generate_and_instrument_tests (
325335 code_context , should_run_experiment = should_run_experiment
326336 )
0 commit comments