Skip to content

Commit aa398f7

Browse files
Merge pull request #792 from codeflash-ai/fix/exactly-mimic-args-in-worktree
mirror the args and test_cfg paths for worktree
2 parents bb35a60 + 6ed7a6c commit aa398f7

File tree

7 files changed

+127
-41
lines changed

7 files changed

+127
-41
lines changed
Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
[tool.codeflash]
2+
# All paths are relative to this pyproject.toml's directory.
3+
module-root = "src/app"
4+
tests-root = "src/tests"
5+
test-framework = "pytest"
6+
ignore-paths = []
7+
disable-telemetry = true
8+
formatter-cmds = ["disabled"]
Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
def sorter(arr):
2+
print("codeflash stdout: Sorting list")
3+
for i in range(len(arr)):
4+
for j in range(len(arr) - 1):
5+
if arr[j] > arr[j + 1]:
6+
temp = arr[j]
7+
arr[j] = arr[j + 1]
8+
arr[j + 1] = temp
9+
print(f"result: {arr}")
10+
return arr

code_to_optimize/code_directories/nested_module_root/src/tests/.gitkeep

Whitespace-only changes.

codeflash/cli_cmds/cli.py

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@
1111
from codeflash.code_utils import env_utils
1212
from codeflash.code_utils.code_utils import exit_with_message
1313
from codeflash.code_utils.config_parser import parse_config_file
14-
from codeflash.code_utils.git_utils import git_root_dir
1514
from codeflash.lsp.helpers import is_LSP_enabled
1615
from codeflash.version import __version__ as version
1716

@@ -223,20 +222,18 @@ def process_pyproject_config(args: Namespace) -> Namespace:
223222
args.module_root = Path(args.module_root).resolve()
224223
# If module-root is "." then all imports are relatives to it.
225224
# in this case, the ".." becomes outside project scope, causing issues with un-importable paths
226-
args.project_root = project_root_from_module_root(args.module_root, pyproject_file_path, args.worktree)
225+
args.project_root = project_root_from_module_root(args.module_root, pyproject_file_path)
227226
args.tests_root = Path(args.tests_root).resolve()
228227
if args.benchmarks_root:
229228
args.benchmarks_root = Path(args.benchmarks_root).resolve()
230-
args.test_project_root = project_root_from_module_root(args.tests_root, pyproject_file_path, args.worktree)
229+
args.test_project_root = project_root_from_module_root(args.tests_root, pyproject_file_path)
231230
if is_LSP_enabled():
232231
args.all = None
233232
return args
234233
return handle_optimize_all_arg_parsing(args)
235234

236235

237-
def project_root_from_module_root(module_root: Path, pyproject_file_path: Path, in_worktree: bool = False) -> Path: # noqa: FBT001, FBT002
238-
if in_worktree:
239-
return git_root_dir()
236+
def project_root_from_module_root(module_root: Path, pyproject_file_path: Path) -> Path:
240237
if pyproject_file_path.parent == module_root:
241238
return module_root
242239
return module_root.parent.resolve()

codeflash/lsp/beta.py

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -115,24 +115,19 @@ def get_optimizable_functions(
115115
server: CodeflashLanguageServer, params: OptimizableFunctionsParams
116116
) -> dict[str, list[str]]:
117117
file_path = Path(uris.to_fs_path(params.textDocument.uri))
118-
server.show_message_log(f"Getting optimizable functions for: {file_path}", "Info")
119118
if not server.optimizer:
120119
return {"status": "error", "message": "optimizer not initialized"}
121120

122121
server.optimizer.args.file = file_path
123122
server.optimizer.args.function = None # Always get ALL functions, not just one
124123
server.optimizer.args.previous_checkpoint_functions = False
125124

126-
server.show_message_log(f"Calling get_optimizable_functions for {server.optimizer.args.file}...", "Info")
127125
optimizable_funcs, _, _ = server.optimizer.get_optimizable_functions()
128126

129127
path_to_qualified_names = {}
130128
for functions in optimizable_funcs.values():
131129
path_to_qualified_names[file_path] = [func.qualified_name for func in functions]
132130

133-
server.show_message_log(
134-
f"Found {len(path_to_qualified_names)} files with functions: {path_to_qualified_names}", "Info"
135-
)
136131
return path_to_qualified_names
137132

138133

@@ -177,7 +172,7 @@ def init_project(server: CodeflashLanguageServer, params: ValidateProjectParams)
177172
else:
178173
return {"status": "error", "message": "No pyproject.toml found in workspace."}
179174

180-
# since we are using worktrees, optimization diffs are generated with respect to the root of the repo, also the args.project_root is set to the root of the repo when creating a worktree
175+
# since we are using worktrees, optimization diffs are generated with respect to the root of the repo.
181176
root = str(git_root_dir())
182177

183178
if getattr(params, "skip_validation", False):

codeflash/optimization/optimizer.py

Lines changed: 40 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
from codeflash.code_utils import env_utils
1616
from codeflash.code_utils.code_utils import cleanup_paths, get_run_tmp_file
1717
from codeflash.code_utils.env_utils import get_pr_number, is_pr_draft
18-
from codeflash.code_utils.git_utils import check_running_in_git_repo
18+
from codeflash.code_utils.git_utils import check_running_in_git_repo, git_root_dir
1919
from codeflash.code_utils.git_worktree_utils import (
2020
create_detached_worktree,
2121
create_diff_patch_from_worktree,
@@ -442,39 +442,50 @@ def worktree_mode(self) -> None:
442442
logger.warning("Failed to create worktree. Skipping optimization.")
443443
return
444444
self.current_worktree = worktree_dir
445-
self.mutate_args_for_worktree_mode(worktree_dir)
445+
self.mirror_paths_for_worktree_mode(worktree_dir)
446446
# make sure the tests dir is created in the worktree, this can happen if the original tests dir is empty
447447
Path(self.args.tests_root).mkdir(parents=True, exist_ok=True)
448448

449-
def mutate_args_for_worktree_mode(self, worktree_dir: Path) -> None:
450-
saved_args = copy.deepcopy(self.args)
451-
saved_test_cfg = copy.deepcopy(self.test_cfg)
452-
self.original_args_and_test_cfg = (saved_args, saved_test_cfg)
453-
454-
project_root = self.args.project_root
455-
module_root = self.args.module_root
456-
relative_module_root = module_root.relative_to(project_root)
457-
relative_optimized_file = self.args.file.relative_to(project_root) if self.args.file else None
458-
relative_tests_root = self.test_cfg.tests_root.relative_to(project_root)
459-
relative_benchmarks_root = (
460-
self.args.benchmarks_root.relative_to(project_root) if self.args.benchmarks_root else None
449+
def mirror_paths_for_worktree_mode(self, worktree_dir: Path) -> None:
450+
original_args = copy.deepcopy(self.args)
451+
original_test_cfg = copy.deepcopy(self.test_cfg)
452+
self.original_args_and_test_cfg = (original_args, original_test_cfg)
453+
454+
original_git_root = git_root_dir()
455+
456+
# mirror project_root
457+
self.args.project_root = mirror_path(self.args.project_root, original_git_root, worktree_dir)
458+
self.test_cfg.project_root_path = mirror_path(self.test_cfg.project_root_path, original_git_root, worktree_dir)
459+
460+
# mirror module_root
461+
self.args.module_root = mirror_path(self.args.module_root, original_git_root, worktree_dir)
462+
463+
# mirror target file
464+
if self.args.file:
465+
self.args.file = mirror_path(self.args.file, original_git_root, worktree_dir)
466+
467+
# mirror tests root
468+
self.args.tests_root = mirror_path(self.args.tests_root, original_git_root, worktree_dir)
469+
self.test_cfg.tests_root = mirror_path(self.test_cfg.tests_root, original_git_root, worktree_dir)
470+
471+
# mirror tests project root
472+
self.args.test_project_root = mirror_path(self.args.test_project_root, original_git_root, worktree_dir)
473+
self.test_cfg.tests_project_rootdir = mirror_path(
474+
self.test_cfg.tests_project_rootdir, original_git_root, worktree_dir
461475
)
462476

463-
self.args.module_root = worktree_dir / relative_module_root
464-
self.args.project_root = worktree_dir
465-
self.args.test_project_root = worktree_dir
466-
self.args.tests_root = worktree_dir / relative_tests_root
467-
if relative_benchmarks_root:
468-
self.args.benchmarks_root = worktree_dir / relative_benchmarks_root
469-
470-
self.test_cfg.project_root_path = worktree_dir
471-
self.test_cfg.tests_project_rootdir = worktree_dir
472-
self.test_cfg.tests_root = worktree_dir / relative_tests_root
473-
if relative_benchmarks_root:
474-
self.test_cfg.benchmark_tests_root = worktree_dir / relative_benchmarks_root
475-
476-
if relative_optimized_file is not None:
477-
self.args.file = worktree_dir / relative_optimized_file
477+
# mirror benchmarks root paths
478+
if self.args.benchmarks_root:
479+
self.args.benchmarks_root = mirror_path(self.args.benchmarks_root, original_git_root, worktree_dir)
480+
if self.test_cfg.benchmark_tests_root:
481+
self.test_cfg.benchmark_tests_root = mirror_path(
482+
self.test_cfg.benchmark_tests_root, original_git_root, worktree_dir
483+
)
484+
485+
486+
def mirror_path(path: Path, src_root: Path, dest_root: Path) -> Path:
487+
relative_path = path.relative_to(src_root)
488+
return dest_root / relative_path
478489

479490

480491
def run_with_args(args: Namespace) -> None:

tests/test_worktree.py

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
from argparse import Namespace
2+
from pathlib import Path
3+
4+
import pytest
5+
from codeflash.cli_cmds.cli import process_pyproject_config
6+
from codeflash.optimization.optimizer import Optimizer
7+
8+
9+
def test_mirror_paths_for_worktree_mode(monkeypatch: pytest.MonkeyPatch):
10+
repo_root = Path(__file__).resolve().parent.parent
11+
project_root = repo_root / "code_to_optimize" / "code_directories" / "nested_module_root"
12+
13+
monkeypatch.setattr("codeflash.optimization.optimizer.git_root_dir", lambda: project_root)
14+
15+
args = Namespace()
16+
args.benchmark = False
17+
args.benchmarks_root = None
18+
19+
args.config_file = project_root / "pyproject.toml"
20+
args.file = project_root / "src" / "app" / "main.py"
21+
args.worktree = True
22+
23+
new_args = process_pyproject_config(args)
24+
25+
optimizer = Optimizer(new_args)
26+
27+
worktree_dir = repo_root / "worktree"
28+
optimizer.mirror_paths_for_worktree_mode(worktree_dir)
29+
30+
assert optimizer.args.project_root == worktree_dir / "src"
31+
assert optimizer.args.test_project_root == worktree_dir / "src"
32+
assert optimizer.args.module_root == worktree_dir / "src" / "app"
33+
assert optimizer.args.tests_root == worktree_dir / "src" / "tests"
34+
assert optimizer.args.file == worktree_dir / "src" / "app" / "main.py"
35+
36+
assert optimizer.test_cfg.tests_root == worktree_dir / "src" / "tests"
37+
assert optimizer.test_cfg.project_root_path == worktree_dir / "src" # same as project_root
38+
assert optimizer.test_cfg.tests_project_rootdir == worktree_dir / "src" # same as test_project_root
39+
40+
# test on our repo
41+
monkeypatch.setattr("codeflash.optimization.optimizer.git_root_dir", lambda: repo_root)
42+
args = Namespace()
43+
args.benchmark = False
44+
args.benchmarks_root = None
45+
46+
args.config_file = repo_root / "pyproject.toml"
47+
args.file = repo_root / "codeflash/optimization/optimizer.py"
48+
args.worktree = True
49+
50+
new_args = process_pyproject_config(args)
51+
52+
optimizer = Optimizer(new_args)
53+
54+
worktree_dir = repo_root / "worktree"
55+
optimizer.mirror_paths_for_worktree_mode(worktree_dir)
56+
57+
assert optimizer.args.project_root == worktree_dir
58+
assert optimizer.args.test_project_root == worktree_dir
59+
assert optimizer.args.module_root == worktree_dir / "codeflash"
60+
assert optimizer.args.tests_root == worktree_dir / "tests"
61+
assert optimizer.args.file == worktree_dir / "codeflash/optimization/optimizer.py"
62+
63+
assert optimizer.test_cfg.tests_root == worktree_dir / "tests"
64+
assert optimizer.test_cfg.project_root_path == worktree_dir # same as project_root
65+
assert optimizer.test_cfg.tests_project_rootdir == worktree_dir # same as test_project_root

0 commit comments

Comments
 (0)