Skip to content

Commit a3907d8

Browse files
authored
Merge branch 'main' into worktree/mirror-all-arg
2 parents e457565 + 952bb6d commit a3907d8

File tree

4 files changed

+123
-27
lines changed

4 files changed

+123
-27
lines changed

codeflash/api/aiservice.py

Lines changed: 85 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,27 +4,30 @@
44
import os
55
import platform
66
import time
7-
from typing import TYPE_CHECKING, Any
7+
from pathlib import Path
8+
from typing import TYPE_CHECKING, Any, cast
89

910
import requests
1011
from pydantic.json import pydantic_encoder
1112

1213
from codeflash.cli_cmds.console import console, logger
14+
from codeflash.code_utils.code_replacer import is_zero_diff
15+
from codeflash.code_utils.code_utils import unified_diff_strings
1316
from codeflash.code_utils.config_consts import N_CANDIDATES_EFFECTIVE, N_CANDIDATES_LP_EFFECTIVE
1417
from codeflash.code_utils.env_utils import get_codeflash_api_key
1518
from codeflash.code_utils.git_utils import get_last_commit_author_if_pr_exists, get_repo_owner_and_name
19+
from codeflash.code_utils.time_utils import humanize_runtime
1620
from codeflash.lsp.helpers import is_LSP_enabled
1721
from codeflash.models.ExperimentMetadata import ExperimentMetadata
1822
from codeflash.models.models import AIServiceRefinerRequest, CodeStringsMarkdown, OptimizedCandidate
1923
from codeflash.telemetry.posthog_cf import ph
2024
from codeflash.version import __version__ as codeflash_version
2125

2226
if TYPE_CHECKING:
23-
from pathlib import Path
24-
2527
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
2628
from codeflash.models.ExperimentMetadata import ExperimentMetadata
2729
from codeflash.models.models import AIServiceRefinerRequest
30+
from codeflash.result.explanation import Explanation
2831

2932

3033
class AiServiceClient:
@@ -529,6 +532,85 @@ def generate_regression_tests( # noqa: D417
529532
ph("cli-testgen-error-response", {"response_status_code": response.status_code, "error": response.text})
530533
return None
531534

535+
def get_optimization_impact(
536+
self,
537+
original_code: dict[Path, str],
538+
new_code: dict[Path, str],
539+
explanation: Explanation,
540+
existing_tests_source: str,
541+
generated_original_test_source: str,
542+
function_trace_id: str,
543+
coverage_message: str,
544+
replay_tests: str,
545+
root_dir: Path,
546+
concolic_tests: str, # noqa: ARG002
547+
) -> str:
548+
"""Compute the optimization impact of current Pull Request.
549+
550+
Args:
551+
original_code: dict -> data structure mapping file paths to function definition for original code
552+
new_code: dict -> data structure mapping file paths to function definition for optimized code
553+
explanation: Explanation -> data structure containing runtime information
554+
existing_tests_source: str -> existing tests table
555+
generated_original_test_source: str -> annotated generated tests
556+
function_trace_id: str -> traceid of function
557+
coverage_message: str -> coverage information
558+
replay_tests: str -> replay test table
559+
root_dir: Path -> path of git directory
560+
concolic_tests: str -> concolic_tests (not used)
561+
562+
Returns:
563+
-------
564+
- 'high' or 'low' optimization impact
565+
566+
"""
567+
diff_str = "\n".join(
568+
[
569+
unified_diff_strings(
570+
code1=original_code[p],
571+
code2=new_code[p],
572+
fromfile=Path(p).relative_to(root_dir).as_posix(),
573+
tofile=Path(p).relative_to(root_dir).as_posix(),
574+
)
575+
for p in original_code
576+
if not is_zero_diff(original_code[p], new_code[p])
577+
]
578+
)
579+
code_diff = f"```diff\n{diff_str}\n```"
580+
logger.info("!lsp|Computing Optimization Impact…")
581+
payload = {
582+
"code_diff": code_diff,
583+
"explanation": explanation.raw_explanation_message,
584+
"existing_tests": existing_tests_source,
585+
"generated_tests": generated_original_test_source,
586+
"trace_id": function_trace_id,
587+
"coverage_message": coverage_message,
588+
"replay_tests": replay_tests,
589+
"speedup": f"{(100 * float(explanation.speedup)):.2f}%",
590+
"loop_count": explanation.winning_benchmarking_test_results.number_of_loops(),
591+
"benchmark_details": explanation.benchmark_details if explanation.benchmark_details else None,
592+
"optimized_runtime": humanize_runtime(explanation.best_runtime_ns),
593+
"original_runtime": humanize_runtime(explanation.original_runtime_ns),
594+
}
595+
console.rule()
596+
try:
597+
response = self.make_ai_service_request("/optimization_impact", payload=payload, timeout=600)
598+
except requests.exceptions.RequestException as e:
599+
logger.exception(f"Error generating optimization refinements: {e}")
600+
ph("cli-optimize-error-caught", {"error": str(e)})
601+
return ""
602+
603+
if response.status_code == 200:
604+
return cast("str", response.json()["impact"])
605+
try:
606+
error = cast("str", response.json()["error"])
607+
except Exception:
608+
error = response.text
609+
logger.error(f"Error generating impact candidates: {response.status_code} - {error}")
610+
ph("cli-optimize-error-response", {"response_status_code": response.status_code, "error": error})
611+
console.rule()
612+
return ""
613+
532614

533615
class LocalAiServiceClient(AiServiceClient):
534616
"""Client for interacting with the local AI service."""

codeflash/discovery/discover_unit_tests.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -417,6 +417,7 @@ def discover_tests_pytest(
417417
with tmp_pickle_path.open(mode="rb") as f:
418418
exitcode, tests, pytest_rootdir = pickle.load(f)
419419
except Exception as e:
420+
tests, pytest_rootdir = [], None
420421
logger.exception(f"Failed to discover tests: {e}")
421422
exitcode = -1
422423
finally:

codeflash/lsp/beta.py

Lines changed: 24 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -131,10 +131,11 @@ def get_optimizable_functions(
131131
return path_to_qualified_names
132132

133133

134-
def _find_pyproject_toml(workspace_path: str) -> Path | None:
134+
def _find_pyproject_toml(workspace_path: str) -> tuple[Path | None, bool]:
135135
workspace_path_obj = Path(workspace_path)
136136
max_depth = 2
137137
base_depth = len(workspace_path_obj.parts)
138+
top_level_pyproject = None
138139

139140
for root, dirs, files in os.walk(workspace_path_obj):
140141
depth = len(Path(root).parts) - base_depth
@@ -145,32 +146,39 @@ def _find_pyproject_toml(workspace_path: str) -> Path | None:
145146

146147
if "pyproject.toml" in files:
147148
file_path = Path(root) / "pyproject.toml"
149+
if depth == 0:
150+
top_level_pyproject = file_path
148151
with file_path.open("r", encoding="utf-8", errors="ignore") as f:
149152
for line in f:
150153
if line.strip() == "[tool.codeflash]":
151-
return file_path.resolve()
152-
return None
154+
return file_path.resolve(), True
155+
return top_level_pyproject, False
153156

154157

155158
# should be called the first thing to initialize and validate the project
156159
@server.feature("initProject")
157-
def init_project(server: CodeflashLanguageServer, params: ValidateProjectParams) -> dict[str, str]:
160+
def init_project(server: CodeflashLanguageServer, params: ValidateProjectParams) -> dict[str, str]: # noqa: PLR0911
158161
from codeflash.cli_cmds.cmd_init import is_valid_pyproject_toml
159162

160-
pyproject_toml_path: Path | None = getattr(params, "config_file", None)
163+
pyproject_toml_path: Path | None = getattr(params, "config_file", None) or getattr(server.args, "config_file", None)
161164

162-
if server.args is None:
163-
if pyproject_toml_path is not None:
164-
# if there is a config file provided use it
165+
if pyproject_toml_path is not None:
166+
# if there is a config file provided use it
167+
server.prepare_optimizer_arguments(pyproject_toml_path)
168+
else:
169+
# otherwise look for it
170+
pyproject_toml_path, has_codeflash_config = _find_pyproject_toml(params.root_path_abs)
171+
if pyproject_toml_path and has_codeflash_config:
172+
server.show_message_log(f"Found pyproject.toml at: {pyproject_toml_path}", "Info")
165173
server.prepare_optimizer_arguments(pyproject_toml_path)
174+
elif pyproject_toml_path and not has_codeflash_config:
175+
return {
176+
"status": "error",
177+
"message": "pyproject.toml found in workspace, but no codeflash config.",
178+
"pyprojectPath": pyproject_toml_path,
179+
}
166180
else:
167-
# otherwise look for it
168-
pyproject_toml_path = _find_pyproject_toml(params.root_path_abs)
169-
server.show_message_log(f"Found pyproject.toml at: {pyproject_toml_path}", "Info")
170-
if pyproject_toml_path:
171-
server.prepare_optimizer_arguments(pyproject_toml_path)
172-
else:
173-
return {"status": "error", "message": "No pyproject.toml found in workspace."}
181+
return {"status": "error", "message": "No pyproject.toml found in workspace."}
174182

175183
# since we are using worktrees, optimization diffs are generated with respect to the root of the repo.
176184
root = str(git_root_dir())
@@ -187,10 +195,7 @@ def init_project(server: CodeflashLanguageServer, params: ValidateProjectParams)
187195
config = is_valid_pyproject_toml(pyproject_toml_path)
188196
if config is None:
189197
server.show_message_log("pyproject.toml is not valid", "Error")
190-
return {
191-
"status": "error",
192-
"message": "pyproject.toml is not valid", # keep the error message the same, the extension is matching "pyproject.toml" in the error message to show the codeflash init instructions,
193-
}
198+
return {"status": "error", "message": "not valid", "pyprojectPath": pyproject_toml_path}
194199

195200
args = process_args(server)
196201
repo = git.Repo(args.module_root, search_parent_directories=True)

codeflash/optimization/function_optimizer.py

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1458,14 +1458,22 @@ def process_review(
14581458
}
14591459

14601460
raise_pr = not self.args.no_pr
1461+
staging_review = self.args.staging_review
14611462

1462-
if raise_pr or self.args.staging_review:
1463+
if raise_pr or staging_review:
14631464
data["root_dir"] = git_root_dir()
1464-
1465-
if raise_pr and not self.args.staging_review:
1465+
try:
1466+
# modify argument of staging vs pr based on the impact
1467+
opt_impact_response = self.aiservice_client.get_optimization_impact(**data)
1468+
if opt_impact_response == "low":
1469+
raise_pr = False
1470+
staging_review = True
1471+
except Exception as e:
1472+
logger.debug(f"optimization impact response failed, investigate {e}")
1473+
if raise_pr and not staging_review:
14661474
data["git_remote"] = self.args.git_remote
14671475
check_create_pr(**data)
1468-
elif self.args.staging_review:
1476+
elif staging_review:
14691477
response = create_staging(**data)
14701478
if response.status_code == 200:
14711479
staging_url = f"https://app.codeflash.ai/review-optimizations/{self.function_trace_id[:-4] + exp_type if self.experiment_id else self.function_trace_id}"
@@ -1504,7 +1512,7 @@ def process_review(
15041512
self.revert_code_and_helpers(original_helper_code)
15051513
return
15061514

1507-
if self.args.staging_review:
1515+
if staging_review:
15081516
# always revert code and helpers when staging review
15091517
self.revert_code_and_helpers(original_helper_code)
15101518
return

0 commit comments

Comments
 (0)