Skip to content

Commit d3e6427

Browse files
[LSP] stderr verbose logs for the thought process (#718)
* lsp silent logs * override other log methods and log serialized lsp messages * send the module root to the lsp client * lsp messages * code print over lsp * more enhancements * more enhancements * cf optimization * log tags for lsp * better markdown support for lsp message logging * simple markdown table * force lsp log (tag) * fixes for cli console logs * small fixes * small fix * it should work this time * prevent worktree log in lsp * logging enhancement * file name for best candidate * reminder * lsp logs formatting and small fixes * typo * fixes for the api key and the lsp gracefull shutdown --------- Co-authored-by: Sarthak Agarwal <sarthak.saga@gmail.com>
1 parent 4856cce commit d3e6427

File tree

17 files changed

+620
-243
lines changed

17 files changed

+620
-243
lines changed

codeflash/api/aiservice.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -133,7 +133,7 @@ def optimize_python_code( # noqa: D417
133133
"repo_name": git_repo_name,
134134
}
135135

136-
logger.info("Generating optimized candidates…")
136+
logger.info("!lsp|Generating optimized candidates…")
137137
console.rule()
138138
try:
139139
response = self.make_ai_service_request("/optimize", payload=payload, timeout=600)
@@ -144,10 +144,10 @@ def optimize_python_code( # noqa: D417
144144

145145
if response.status_code == 200:
146146
optimizations_json = response.json()["optimizations"]
147-
logger.info(f"Generated {len(optimizations_json)} candidate optimizations.")
147+
logger.info(f"!lsp|Generated {len(optimizations_json)} candidate optimizations.")
148148
console.rule()
149149
end_time = time.perf_counter()
150-
logger.debug(f"Generating optimizations took {end_time - start_time:.2f} seconds.")
150+
logger.debug(f"!lsp|Generating possible optimizations took {end_time - start_time:.2f} seconds.")
151151
return self._get_valid_candidates(optimizations_json)
152152
try:
153153
error = response.json()["error"]
@@ -194,7 +194,6 @@ def optimize_python_code_line_profiler( # noqa: D417
194194
"lsp_mode": is_LSP_enabled(),
195195
}
196196

197-
logger.info("Generating optimized candidates…")
198197
console.rule()
199198
if line_profiler_results == "":
200199
logger.info("No LineProfiler results were provided, Skipping optimization.")
@@ -209,7 +208,9 @@ def optimize_python_code_line_profiler( # noqa: D417
209208

210209
if response.status_code == 200:
211210
optimizations_json = response.json()["optimizations"]
212-
logger.info(f"Generated {len(optimizations_json)} candidate optimizations using line profiler information.")
211+
logger.info(
212+
f"!lsp|Generated {len(optimizations_json)} candidate optimizations using line profiler information."
213+
)
213214
console.rule()
214215
return self._get_valid_candidates(optimizations_json)
215216
try:
@@ -331,7 +332,7 @@ def get_new_explanation( # noqa: D417
331332
"original_explanation": original_explanation,
332333
"dependency_code": dependency_code,
333334
}
334-
logger.info("Generating explanation")
335+
logger.info("loading|Generating explanation")
335336
console.rule()
336337
try:
337338
response = self.make_ai_service_request("/explain", payload=payload, timeout=60)
@@ -376,7 +377,7 @@ def generate_ranking( # noqa: D417
376377
"optimization_ids": optimization_ids,
377378
"python_version": platform.python_version(),
378379
}
379-
logger.info("Generating ranking")
380+
logger.info("loading|Generating ranking")
380381
console.rule()
381382
try:
382383
response = self.make_ai_service_request("/rank", payload=payload, timeout=60)

codeflash/api/cfapi.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@ def make_cfapi_request(
4040
payload: dict[str, Any] | None = None,
4141
extra_headers: dict[str, str] | None = None,
4242
*,
43+
api_key: str | None = None,
4344
suppress_errors: bool = False,
4445
) -> Response:
4546
"""Make an HTTP request using the specified method, URL, headers, and JSON payload.
@@ -51,7 +52,7 @@ def make_cfapi_request(
5152
:return: The response object from the API.
5253
"""
5354
url = f"{CFAPI_BASE_URL}/cfapi{endpoint}"
54-
cfapi_headers = {"Authorization": f"Bearer {get_codeflash_api_key()}"}
55+
cfapi_headers = {"Authorization": f"Bearer {api_key or get_codeflash_api_key()}"}
5556
if extra_headers:
5657
cfapi_headers.update(extra_headers)
5758
try:
@@ -83,15 +84,17 @@ def make_cfapi_request(
8384

8485

8586
@lru_cache(maxsize=1)
86-
def get_user_id() -> Optional[str]:
87+
def get_user_id(api_key: Optional[str] = None) -> Optional[str]:
8788
"""Retrieve the user's userid by making a request to the /cfapi/cli-get-user endpoint.
8889
8990
:return: The userid or None if the request fails.
9091
"""
9192
if not ensure_codeflash_api_key():
9293
return None
9394

94-
response = make_cfapi_request(endpoint="/cli-get-user", method="GET", extra_headers={"cli_version": __version__})
95+
response = make_cfapi_request(
96+
endpoint="/cli-get-user", method="GET", extra_headers={"cli_version": __version__}, api_key=api_key
97+
)
9598
if response.status_code == 200:
9699
if "min_version" not in response.text:
97100
return response.text

codeflash/cli_cmds/console.py

Lines changed: 34 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,9 @@
11
from __future__ import annotations
22

33
import logging
4-
import os
54
from contextlib import contextmanager
65
from itertools import cycle
7-
from typing import TYPE_CHECKING
6+
from typing import TYPE_CHECKING, Optional
87

98
from rich.console import Console
109
from rich.logging import RichHandler
@@ -20,17 +19,22 @@
2019

2120
from codeflash.cli_cmds.console_constants import SPINNER_TYPES
2221
from codeflash.cli_cmds.logging_config import BARE_LOGGING_FORMAT
22+
from codeflash.lsp.helpers import is_LSP_enabled
23+
from codeflash.lsp.lsp_logger import enhanced_log
24+
from codeflash.lsp.lsp_message import LspCodeMessage, LspTextMessage
2325

2426
if TYPE_CHECKING:
2527
from collections.abc import Generator
2628

2729
from rich.progress import TaskID
2830

31+
from codeflash.lsp.lsp_message import LspMessage
32+
2933
DEBUG_MODE = logging.getLogger().getEffectiveLevel() == logging.DEBUG
3034

3135
console = Console()
3236

33-
if os.getenv("CODEFLASH_LSP"):
37+
if is_LSP_enabled():
3438
console.quiet = True
3539

3640
logging.basicConfig(
@@ -42,6 +46,24 @@
4246
logger = logging.getLogger("rich")
4347
logging.getLogger("parso").setLevel(logging.WARNING)
4448

49+
# override the logger to reformat the messages for the lsp
50+
for level in ("info", "debug", "warning", "error"):
51+
real_fn = getattr(logger, level)
52+
setattr(
53+
logger,
54+
level,
55+
lambda msg, *args, _real_fn=real_fn, _level=level, **kwargs: enhanced_log(
56+
msg, _real_fn, _level, *args, **kwargs
57+
),
58+
)
59+
60+
61+
def lsp_log(message: LspMessage) -> None:
62+
if not is_LSP_enabled():
63+
return
64+
json_msg = message.serialize()
65+
logger.info(json_msg)
66+
4567

4668
def paneled_text(
4769
text: str, panel_args: dict[str, str | bool] | None = None, text_args: dict[str, str] | None = None
@@ -58,7 +80,10 @@ def paneled_text(
5880
console.print(panel)
5981

6082

61-
def code_print(code_str: str) -> None:
83+
def code_print(code_str: str, file_name: Optional[str] = None, function_name: Optional[str] = None) -> None:
84+
if is_LSP_enabled():
85+
lsp_log(LspCodeMessage(code=code_str, file_name=file_name, function_name=function_name))
86+
return
6287
"""Print code with syntax highlighting."""
6388
from rich.syntax import Syntax
6489

@@ -79,6 +104,11 @@ def progress_bar(
79104
If revert_to_print is True, falls back to printing a single logger.info message
80105
instead of showing a progress bar.
81106
"""
107+
if is_LSP_enabled():
108+
lsp_log(LspTextMessage(text=message, takes_time=True))
109+
yield
110+
return
111+
82112
if revert_to_print:
83113
logger.info(message)
84114

codeflash/code_utils/git_worktree_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@ def create_detached_worktree(module_root: Path) -> Optional[Path]:
6262
)
6363

6464
if not uni_diff_text.strip():
65-
logger.info("No uncommitted changes to copy to worktree.")
65+
logger.info("!lsp|No uncommitted changes to copy to worktree.")
6666
return worktree_dir
6767

6868
# Write the diff to a temporary file

codeflash/discovery/functions_to_optimize.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -173,15 +173,15 @@ def get_functions_to_optimize(
173173
with warnings.catch_warnings():
174174
warnings.simplefilter(action="ignore", category=SyntaxWarning)
175175
if optimize_all:
176-
logger.info("Finding all functions in the module '%s'…", optimize_all)
176+
logger.info("!lsp|Finding all functions in the module '%s'…", optimize_all)
177177
console.rule()
178178
functions = get_all_files_and_functions(Path(optimize_all))
179179
elif replay_test:
180180
functions, trace_file_path = get_all_replay_test_functions(
181181
replay_test=replay_test, test_cfg=test_cfg, project_root_path=project_root
182182
)
183183
elif file is not None:
184-
logger.info("Finding all functions in the file '%s'…", file)
184+
logger.info("!lsp|Finding all functions in the file '%s'…", file)
185185
console.rule()
186186
functions = find_all_functions_in_file(file)
187187
if only_get_this_function is not None:
@@ -219,7 +219,7 @@ def get_functions_to_optimize(
219219
functions, test_cfg.tests_root, ignore_paths, project_root, module_root, previous_checkpoint_functions
220220
)
221221

222-
logger.info(f"Found {functions_count} function{'s' if functions_count > 1 else ''} to optimize")
222+
logger.info(f"!lsp|Found {functions_count} function{'s' if functions_count > 1 else ''} to optimize")
223223
if optimize_all:
224224
three_min_in_ns = int(1.8e11)
225225
console.rule()

codeflash/lsp/beta.py

Lines changed: 35 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -4,13 +4,14 @@
44
import os
55
from dataclasses import dataclass
66
from pathlib import Path
7-
from typing import TYPE_CHECKING
7+
from typing import TYPE_CHECKING, Optional
88

99
import git
1010
from pygls import uris
1111

1212
from codeflash.api.cfapi import get_codeflash_api_key, get_user_id
1313
from codeflash.cli_cmds.cli import process_pyproject_config
14+
from codeflash.cli_cmds.console import code_print
1415
from codeflash.code_utils.git_worktree_utils import (
1516
create_diff_patch_from_worktree,
1617
get_patches_metadata,
@@ -103,6 +104,8 @@ def get_optimizable_functions(
103104
) -> dict[str, list[str]]:
104105
file_path = Path(uris.to_fs_path(params.textDocument.uri))
105106
server.show_message_log(f"Getting optimizable functions for: {file_path}", "Info")
107+
if not server.optimizer:
108+
return {"status": "error", "message": "optimizer not initialized"}
106109

107110
server.optimizer.args.file = file_path
108111
server.optimizer.args.function = None # Always get ALL functions, not just one
@@ -157,20 +160,6 @@ def initialize_function_optimization(
157160
return {"functionName": params.functionName, "status": "success"}
158161

159162

160-
@server.feature("discoverFunctionTests")
161-
def discover_function_tests(server: CodeflashLanguageServer, params: FunctionOptimizationParams) -> dict[str, str]:
162-
fto = server.optimizer.current_function_being_optimized
163-
optimizable_funcs = {fto.file_path: [fto]}
164-
165-
devnull_writer = open(os.devnull, "w") # noqa
166-
with contextlib.redirect_stdout(devnull_writer):
167-
function_to_tests, num_discovered_tests = server.optimizer.discover_tests(optimizable_funcs)
168-
169-
server.optimizer.discovered_tests = function_to_tests
170-
171-
return {"functionName": params.functionName, "status": "success", "discovered_tests": num_discovered_tests}
172-
173-
174163
@server.feature("validateProject")
175164
def validate_project(server: CodeflashLanguageServer, _params: FunctionOptimizationParams) -> dict[str, str]:
176165
from codeflash.cli_cmds.cmd_init import is_valid_pyproject_toml
@@ -194,11 +183,13 @@ def validate_project(server: CodeflashLanguageServer, _params: FunctionOptimizat
194183
except Exception:
195184
return {"status": "error", "message": "Repository has no commits (unborn HEAD)"}
196185

197-
return {"status": "success"}
186+
return {"status": "success", "moduleRoot": args.module_root}
198187

199188

200-
def _initialize_optimizer_if_api_key_is_valid(server: CodeflashLanguageServer) -> dict[str, str]:
201-
user_id = get_user_id()
189+
def _initialize_optimizer_if_api_key_is_valid(
190+
server: CodeflashLanguageServer, api_key: Optional[str] = None
191+
) -> dict[str, str]:
192+
user_id = get_user_id(api_key=api_key)
202193
if user_id is None:
203194
return {"status": "error", "message": "api key not found or invalid"}
204195

@@ -237,19 +228,19 @@ def provide_api_key(server: CodeflashLanguageServer, params: ProvideApiKeyParams
237228
if not api_key.startswith("cf-"):
238229
return {"status": "error", "message": "Api key is not valid"}
239230

240-
result = save_api_key_to_rc(api_key)
241-
if not is_successful(result):
242-
return {"status": "error", "message": result.failure()}
243-
244231
# clear cache to ensure the new api key is used
245232
get_codeflash_api_key.cache_clear()
246233
get_user_id.cache_clear()
247234

248-
init_result = _initialize_optimizer_if_api_key_is_valid(server)
235+
init_result = _initialize_optimizer_if_api_key_is_valid(server, api_key)
249236
if init_result["status"] == "error":
250237
return {"status": "error", "message": "Api key is not valid"}
251238

252-
return {"status": "success", "message": "Api key saved successfully", "user_id": init_result["user_id"]}
239+
user_id = init_result["user_id"]
240+
result = save_api_key_to_rc(api_key)
241+
if not is_successful(result):
242+
return {"status": "error", "message": result.failure()}
243+
return {"status": "success", "message": "Api key saved successfully", "user_id": user_id} # noqa: TRY300
253244
except Exception:
254245
return {"status": "error", "message": "something went wrong while saving the api key"}
255246

@@ -300,6 +291,12 @@ def perform_function_optimization( # noqa: PLR0911
300291
}
301292

302293
module_prep_result = server.optimizer.prepare_module_for_optimization(current_function.file_path)
294+
if not module_prep_result:
295+
return {
296+
"functionName": params.functionName,
297+
"status": "error",
298+
"message": "Failed to prepare module for optimization",
299+
}
303300

304301
validated_original_code, original_module_ast = module_prep_result
305302

@@ -308,7 +305,7 @@ def perform_function_optimization( # noqa: PLR0911
308305
function_to_optimize_source_code=validated_original_code[current_function.file_path].source_code,
309306
original_module_ast=original_module_ast,
310307
original_module_path=current_function.file_path,
311-
function_to_tests=server.optimizer.discovered_tests or {},
308+
function_to_tests={},
312309
)
313310

314311
server.optimizer.current_function_optimizer = function_optimizer
@@ -321,6 +318,19 @@ def perform_function_optimization( # noqa: PLR0911
321318

322319
should_run_experiment, code_context, original_helper_code = initialization_result.unwrap()
323320

321+
code_print(
322+
code_context.read_writable_code.flat,
323+
file_name=current_function.file_path,
324+
function_name=current_function.function_name,
325+
)
326+
327+
optimizable_funcs = {current_function.file_path: [current_function]}
328+
329+
devnull_writer = open(os.devnull, "w") # noqa
330+
with contextlib.redirect_stdout(devnull_writer):
331+
function_to_tests, num_discovered_tests = server.optimizer.discover_tests(optimizable_funcs)
332+
function_optimizer.function_to_tests = function_to_tests
333+
324334
test_setup_result = function_optimizer.generate_and_instrument_tests(
325335
code_context, should_run_experiment=should_run_experiment
326336
)

0 commit comments

Comments
 (0)