Skip to content

Commit e525edc

Browse files
authored
Merge branch 'main' into 3.14-in-CI
2 parents bf633fe + ffd8c90 commit e525edc

17 files changed

+576
-189
lines changed

.github/workflows/e2e-init-optimization.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ jobs:
1919
COLUMNS: 110
2020
MAX_RETRIES: 3
2121
RETRY_DELAY: 5
22-
EXPECTED_IMPROVEMENT_PCT: 30
22+
EXPECTED_IMPROVEMENT_PCT: 10
2323
CODEFLASH_END_TO_END: 1
2424
steps:
2525
- name: 🛎️ Checkout

.github/workflows/e2e-topological-sort.yaml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
name: E2E - Topological Sort
1+
name: E2E - Topological Sort (Worktree)
22

33
on:
44
pull_request:
@@ -8,7 +8,7 @@ on:
88
workflow_dispatch:
99

1010
jobs:
11-
topological-sort-optimization:
11+
topological-sort-worktree-optimization:
1212
# Dynamically determine if environment is needed only when workflow files change and contributor is external
1313
environment: ${{ (github.event_name == 'workflow_dispatch' || (contains(toJSON(github.event.pull_request.files.*.filename), '.github/workflows/') && github.event.pull_request.user.login != 'misrasaurabh1' && github.event.pull_request.user.login != 'KRRT7')) && 'external-trusted-contributors' || '' }}
1414
runs-on: ubuntu-latest
@@ -90,4 +90,4 @@ jobs:
9090
- name: Run Codeflash to optimize code
9191
id: optimize_code
9292
run: |
93-
uv run python tests/scripts/end_to_end_test_topological_sort.py
93+
uv run python tests/scripts/end_to_end_test_topological_sort_worktree.py

codeflash/cli_cmds/cmd_init.py

Lines changed: 21 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -155,20 +155,30 @@ def ask_run_end_to_end_test(args: Namespace) -> None:
155155
run_end_to_end_test(args, bubble_sort_path, bubble_sort_test_path)
156156

157157

158-
def is_valid_pyproject_toml(pyproject_toml_path: Path) -> dict[str, Any] | None:
158+
def is_valid_pyproject_toml(pyproject_toml_path: Path) -> tuple[dict[str, Any] | None, str]: # noqa: PLR0911
159159
if not pyproject_toml_path.exists():
160-
return None
160+
return None, f"Configuration file not found: {pyproject_toml_path}"
161+
161162
try:
162163
config, _ = parse_config_file(pyproject_toml_path)
163-
except Exception:
164-
return None
164+
except Exception as e:
165+
return None, f"Failed to parse configuration: {e}"
166+
167+
module_root = config.get("module_root")
168+
if not module_root:
169+
return None, "Missing required field: 'module_root'"
170+
171+
if not Path(module_root).is_dir():
172+
return None, f"Invalid 'module_root': directory does not exist at {module_root}"
173+
174+
tests_root = config.get("tests_root")
175+
if not tests_root:
176+
return None, "Missing required field: 'tests_root'"
165177

166-
if "module_root" not in config or config["module_root"] is None or not Path(config["module_root"]).is_dir():
167-
return None
168-
if "tests_root" not in config or config["tests_root"] is None or not Path(config["tests_root"]).is_dir():
169-
return None
178+
if not Path(tests_root).is_dir():
179+
return None, f"Invalid 'tests_root': directory does not exist at {tests_root}"
170180

171-
return config
181+
return config, ""
172182

173183

174184
def should_modify_pyproject_toml() -> tuple[bool, dict[str, Any] | None]:
@@ -180,7 +190,7 @@ def should_modify_pyproject_toml() -> tuple[bool, dict[str, Any] | None]:
180190

181191
pyproject_toml_path = Path.cwd() / "pyproject.toml"
182192

183-
config = is_valid_pyproject_toml(pyproject_toml_path)
193+
config, _message = is_valid_pyproject_toml(pyproject_toml_path)
184194
if config is None:
185195
return True, None
186196

@@ -199,9 +209,7 @@ def __init__(self) -> None:
199209
self.Question.brackets_color = inquirer.themes.term.bright_blue
200210
self.Question.default_color = inquirer.themes.term.bright_cyan
201211
self.List.selection_color = inquirer.themes.term.bright_blue
202-
self.List.selection_cursor = "⚡"
203212
self.Checkbox.selection_color = inquirer.themes.term.bright_blue
204-
self.Checkbox.selection_cursor = "⚡"
205213
self.Checkbox.selected_icon = "✅"
206214
self.Checkbox.unselected_icon = "⬜"
207215

@@ -633,7 +641,7 @@ def check_for_toml_or_setup_file() -> str | None:
633641

634642
def install_github_actions(override_formatter_check: bool = False) -> None: # noqa: FBT001, FBT002
635643
try:
636-
config, config_file_path = parse_config_file(override_formatter_check=override_formatter_check)
644+
config, _config_file_path = parse_config_file(override_formatter_check=override_formatter_check)
637645

638646
ph("cli-github-actions-install-started")
639647
try:

codeflash/code_utils/code_extractor.py

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -528,16 +528,29 @@ def add_needed_imports_from_module(
528528

529529
try:
530530
for mod in gatherer.module_imports:
531+
# Skip __future__ imports as they cannot be imported directly
532+
# __future__ imports should only be imported with specific objects i.e from __future__ import annotations
533+
if mod == "__future__":
534+
continue
531535
if mod not in dotted_import_collector.imports:
532536
AddImportsVisitor.add_needed_import(dst_context, mod)
533537
RemoveImportsVisitor.remove_unused_import(dst_context, mod)
538+
aliased_objects = set()
539+
for mod, alias_pairs in gatherer.alias_mapping.items():
540+
for alias_pair in alias_pairs:
541+
if alias_pair[0] and alias_pair[1]: # Both name and alias exist
542+
aliased_objects.add(f"{mod}.{alias_pair[0]}")
543+
534544
for mod, obj_seq in gatherer.object_mapping.items():
535545
for obj in obj_seq:
536546
if (
537547
f"{mod}.{obj}" in helper_functions_fqn or dst_context.full_module_name == mod # avoid circular deps
538548
):
539549
continue # Skip adding imports for helper functions already in the context
540550

551+
if f"{mod}.{obj}" in aliased_objects:
552+
continue
553+
541554
# Handle star imports by resolving them to actual symbol names
542555
if obj == "*":
543556
resolved_symbols = resolve_star_import(mod, project_root)
@@ -559,6 +572,8 @@ def add_needed_imports_from_module(
559572
return dst_module_code
560573

561574
for mod, asname in gatherer.module_aliases.items():
575+
if not asname:
576+
continue
562577
if f"{mod}.{asname}" not in dotted_import_collector.imports:
563578
AddImportsVisitor.add_needed_import(dst_context, mod, asname=asname)
564579
RemoveImportsVisitor.remove_unused_import(dst_context, mod, asname=asname)
@@ -568,12 +583,16 @@ def add_needed_imports_from_module(
568583
if f"{mod}.{alias_pair[0]}" in helper_functions_fqn:
569584
continue
570585

586+
if not alias_pair[0] or not alias_pair[1]:
587+
continue
588+
571589
if f"{mod}.{alias_pair[1]}" not in dotted_import_collector.imports:
572590
AddImportsVisitor.add_needed_import(dst_context, mod, alias_pair[0], asname=alias_pair[1])
573591
RemoveImportsVisitor.remove_unused_import(dst_context, mod, alias_pair[0], asname=alias_pair[1])
574592

575593
try:
576-
transformed_module = AddImportsVisitor(dst_context).transform_module(parsed_dst_module)
594+
add_imports_visitor = AddImportsVisitor(dst_context)
595+
transformed_module = add_imports_visitor.transform_module(parsed_dst_module)
577596
transformed_module = RemoveImportsVisitor(dst_context).transform_module(transformed_module)
578597
return transformed_module.code.lstrip("\n")
579598
except Exception as e:

codeflash/code_utils/git_worktree_utils.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,12 @@ def get_git_project_id() -> str:
3434

3535
def create_worktree_snapshot_commit(worktree_dir: Path, commit_message: str) -> None:
3636
repository = git.Repo(worktree_dir, search_parent_directories=True)
37+
with repository.config_writer() as cw:
38+
if not cw.has_option("user", "name"):
39+
cw.set_value("user", "name", "Codeflash Bot")
40+
if not cw.has_option("user", "email"):
41+
cw.set_value("user", "email", "bot@codeflash.ai")
42+
3743
repository.git.add(".")
3844
repository.git.commit("-m", commit_message, "--no-verify")
3945

codeflash/lsp/beta.py

Lines changed: 15 additions & 124 deletions
Original file line numberDiff line numberDiff line change
@@ -1,26 +1,24 @@
11
from __future__ import annotations
22

3-
import contextlib
3+
import asyncio
44
import os
55
from dataclasses import dataclass
66
from pathlib import Path
77
from typing import TYPE_CHECKING, Optional
88

9-
import git
109
from pygls import uris
1110

1211
from codeflash.api.cfapi import get_codeflash_api_key, get_user_id
1312
from codeflash.cli_cmds.cli import process_pyproject_config
14-
from codeflash.cli_cmds.console import code_print
1513
from codeflash.code_utils.git_utils import git_root_dir
16-
from codeflash.code_utils.git_worktree_utils import create_diff_patch_from_worktree
1714
from codeflash.code_utils.shell_utils import save_api_key_to_rc
1815
from codeflash.discovery.functions_to_optimize import (
1916
filter_functions,
2017
get_functions_inside_a_commit,
2118
get_functions_within_git_diff,
2219
)
2320
from codeflash.either import is_successful
21+
from codeflash.lsp.features.perform_optimization import sync_perform_optimization
2422
from codeflash.lsp.server import CodeflashLanguageServer
2523

2624
if TYPE_CHECKING:
@@ -71,7 +69,6 @@ class OptimizableFunctionsInCommitParams:
7169
commit_hash: str
7270

7371

74-
# server = CodeflashLanguageServer("codeflash-language-server", "v1.0", protocol_cls=CodeflashLanguageServerProtocol)
7572
server = CodeflashLanguageServer("codeflash-language-server", "v1.0")
7673

7774

@@ -157,11 +154,13 @@ def _find_pyproject_toml(workspace_path: str) -> tuple[Path | None, bool]:
157154

158155
# should be called the first thing to initialize and validate the project
159156
@server.feature("initProject")
160-
def init_project(server: CodeflashLanguageServer, params: ValidateProjectParams) -> dict[str, str]: # noqa: PLR0911
157+
def init_project(server: CodeflashLanguageServer, params: ValidateProjectParams) -> dict[str, str]:
161158
from codeflash.cli_cmds.cmd_init import is_valid_pyproject_toml
162159

163-
pyproject_toml_path: Path | None = getattr(params, "config_file", None) or getattr(server.args, "config_file", None)
160+
# Always process args in the init project, the extension can call
161+
server.args_processed_before = False
164162

163+
pyproject_toml_path: Path | None = getattr(params, "config_file", None) or getattr(server.args, "config_file", None)
165164
if pyproject_toml_path is not None:
166165
# if there is a config file provided use it
167166
server.prepare_optimizer_arguments(pyproject_toml_path)
@@ -192,20 +191,12 @@ def init_project(server: CodeflashLanguageServer, params: ValidateProjectParams)
192191
}
193192

194193
server.show_message_log("Validating project...", "Info")
195-
config = is_valid_pyproject_toml(pyproject_toml_path)
194+
config, reason = is_valid_pyproject_toml(pyproject_toml_path)
196195
if config is None:
197196
server.show_message_log("pyproject.toml is not valid", "Error")
198-
return {"status": "error", "message": "not valid", "pyprojectPath": pyproject_toml_path}
197+
return {"status": "error", "message": f"reason: {reason}", "pyprojectPath": pyproject_toml_path}
199198

200199
args = process_args(server)
201-
repo = git.Repo(args.module_root, search_parent_directories=True)
202-
if repo.bare:
203-
return {"status": "error", "message": "Repository is in bare state"}
204-
205-
try:
206-
_ = repo.head.commit
207-
except Exception:
208-
return {"status": "error", "message": "Repository has no commits (unborn HEAD)"}
209200

210201
return {"status": "success", "moduleRoot": args.module_root, "pyprojectPath": pyproject_toml_path, "root": root}
211202

@@ -339,115 +330,15 @@ def initialize_function_optimization(
339330

340331

341332
@server.feature("performFunctionOptimization")
342-
@server.thread()
343-
def perform_function_optimization(
333+
async def perform_function_optimization(
344334
server: CodeflashLanguageServer, params: FunctionOptimizationParams
345335
) -> dict[str, str]:
336+
loop = asyncio.get_running_loop()
346337
try:
347-
server.show_message_log(f"Starting optimization for function: {params.functionName}", "Info")
348-
should_run_experiment, code_context, original_helper_code = server.current_optimization_init_result
349-
function_optimizer = server.optimizer.current_function_optimizer
350-
current_function = function_optimizer.function_to_optimize
351-
352-
code_print(
353-
code_context.read_writable_code.flat,
354-
file_name=current_function.file_path,
355-
function_name=current_function.function_name,
356-
)
357-
358-
optimizable_funcs = {current_function.file_path: [current_function]}
359-
360-
devnull_writer = open(os.devnull, "w") # noqa
361-
with contextlib.redirect_stdout(devnull_writer):
362-
function_to_tests, _num_discovered_tests = server.optimizer.discover_tests(optimizable_funcs)
363-
function_optimizer.function_to_tests = function_to_tests
364-
365-
test_setup_result = function_optimizer.generate_and_instrument_tests(
366-
code_context, should_run_experiment=should_run_experiment
367-
)
368-
if not is_successful(test_setup_result):
369-
return {"functionName": params.functionName, "status": "error", "message": test_setup_result.failure()}
370-
(
371-
generated_tests,
372-
function_to_concolic_tests,
373-
concolic_test_str,
374-
optimizations_set,
375-
generated_test_paths,
376-
generated_perf_test_paths,
377-
instrumented_unittests_created_for_function,
378-
original_conftest_content,
379-
) = test_setup_result.unwrap()
380-
381-
baseline_setup_result = function_optimizer.setup_and_establish_baseline(
382-
code_context=code_context,
383-
original_helper_code=original_helper_code,
384-
function_to_concolic_tests=function_to_concolic_tests,
385-
generated_test_paths=generated_test_paths,
386-
generated_perf_test_paths=generated_perf_test_paths,
387-
instrumented_unittests_created_for_function=instrumented_unittests_created_for_function,
388-
original_conftest_content=original_conftest_content,
389-
)
390-
391-
if not is_successful(baseline_setup_result):
392-
return {"functionName": params.functionName, "status": "error", "message": baseline_setup_result.failure()}
393-
394-
(
395-
function_to_optimize_qualified_name,
396-
function_to_all_tests,
397-
original_code_baseline,
398-
test_functions_to_remove,
399-
file_path_to_helper_classes,
400-
) = baseline_setup_result.unwrap()
401-
402-
best_optimization = function_optimizer.find_and_process_best_optimization(
403-
optimizations_set=optimizations_set,
404-
code_context=code_context,
405-
original_code_baseline=original_code_baseline,
406-
original_helper_code=original_helper_code,
407-
file_path_to_helper_classes=file_path_to_helper_classes,
408-
function_to_optimize_qualified_name=function_to_optimize_qualified_name,
409-
function_to_all_tests=function_to_all_tests,
410-
generated_tests=generated_tests,
411-
test_functions_to_remove=test_functions_to_remove,
412-
concolic_test_str=concolic_test_str,
413-
)
414-
415-
if not best_optimization:
416-
server.show_message_log(
417-
f"No best optimizations found for function {function_to_optimize_qualified_name}", "Warning"
418-
)
419-
return {
420-
"functionName": params.functionName,
421-
"status": "error",
422-
"message": f"No best optimizations found for function {function_to_optimize_qualified_name}",
423-
}
424-
425-
# generate a patch for the optimization
426-
relative_file_paths = [code_string.file_path for code_string in code_context.read_writable_code.code_strings]
427-
428-
speedup = original_code_baseline.runtime / best_optimization.runtime
429-
430-
patch_path = create_diff_patch_from_worktree(
431-
server.optimizer.current_worktree, relative_file_paths, function_to_optimize_qualified_name
432-
)
433-
434-
if not patch_path:
435-
return {
436-
"functionName": params.functionName,
437-
"status": "error",
438-
"message": "Failed to create a patch for optimization",
439-
}
440-
441-
server.show_message_log(f"Optimization completed for {params.functionName} with {speedup:.2f}x speedup", "Info")
442-
443-
return {
444-
"functionName": params.functionName,
445-
"status": "success",
446-
"message": "Optimization completed successfully",
447-
"extra": f"Speedup: {speedup:.2f}x faster",
448-
"patch_file": str(patch_path),
449-
"task_id": params.task_id,
450-
"explanation": best_optimization.explanation_v2,
451-
}
338+
result = await loop.run_in_executor(None, sync_perform_optimization, server, params)
339+
except asyncio.CancelledError:
340+
return {"status": "canceled", "message": "Task was canceled"}
341+
else:
342+
return result
452343
finally:
453344
server.cleanup_the_optimizer()

codeflash/lsp/features/__init__.py

Whitespace-only changes.

0 commit comments

Comments
 (0)