Skip to content

Commit 10b89fa

Browse files
mirror the args in worktree
1 parent bb35a60 commit 10b89fa

File tree

3 files changed

+45
-36
lines changed

3 files changed

+45
-36
lines changed

codeflash/cli_cmds/cli.py

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@
1111
from codeflash.code_utils import env_utils
1212
from codeflash.code_utils.code_utils import exit_with_message
1313
from codeflash.code_utils.config_parser import parse_config_file
14-
from codeflash.code_utils.git_utils import git_root_dir
1514
from codeflash.lsp.helpers import is_LSP_enabled
1615
from codeflash.version import __version__ as version
1716

@@ -223,20 +222,18 @@ def process_pyproject_config(args: Namespace) -> Namespace:
223222
args.module_root = Path(args.module_root).resolve()
224223
# If module-root is "." then all imports are relatives to it.
225224
# in this case, the ".." becomes outside project scope, causing issues with un-importable paths
226-
args.project_root = project_root_from_module_root(args.module_root, pyproject_file_path, args.worktree)
225+
args.project_root = project_root_from_module_root(args.module_root, pyproject_file_path)
227226
args.tests_root = Path(args.tests_root).resolve()
228227
if args.benchmarks_root:
229228
args.benchmarks_root = Path(args.benchmarks_root).resolve()
230-
args.test_project_root = project_root_from_module_root(args.tests_root, pyproject_file_path, args.worktree)
229+
args.test_project_root = project_root_from_module_root(args.tests_root, pyproject_file_path)
231230
if is_LSP_enabled():
232231
args.all = None
233232
return args
234233
return handle_optimize_all_arg_parsing(args)
235234

236235

237-
def project_root_from_module_root(module_root: Path, pyproject_file_path: Path, in_worktree: bool = False) -> Path: # noqa: FBT001, FBT002
238-
if in_worktree:
239-
return git_root_dir()
236+
def project_root_from_module_root(module_root: Path, pyproject_file_path: Path) -> Path:
240237
if pyproject_file_path.parent == module_root:
241238
return module_root
242239
return module_root.parent.resolve()

codeflash/lsp/beta.py

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -115,24 +115,19 @@ def get_optimizable_functions(
115115
server: CodeflashLanguageServer, params: OptimizableFunctionsParams
116116
) -> dict[str, list[str]]:
117117
file_path = Path(uris.to_fs_path(params.textDocument.uri))
118-
server.show_message_log(f"Getting optimizable functions for: {file_path}", "Info")
119118
if not server.optimizer:
120119
return {"status": "error", "message": "optimizer not initialized"}
121120

122121
server.optimizer.args.file = file_path
123122
server.optimizer.args.function = None # Always get ALL functions, not just one
124123
server.optimizer.args.previous_checkpoint_functions = False
125124

126-
server.show_message_log(f"Calling get_optimizable_functions for {server.optimizer.args.file}...", "Info")
127125
optimizable_funcs, _, _ = server.optimizer.get_optimizable_functions()
128126

129127
path_to_qualified_names = {}
130128
for functions in optimizable_funcs.values():
131129
path_to_qualified_names[file_path] = [func.qualified_name for func in functions]
132130

133-
server.show_message_log(
134-
f"Found {len(path_to_qualified_names)} files with functions: {path_to_qualified_names}", "Info"
135-
)
136131
return path_to_qualified_names
137132

138133

@@ -177,7 +172,7 @@ def init_project(server: CodeflashLanguageServer, params: ValidateProjectParams)
177172
else:
178173
return {"status": "error", "message": "No pyproject.toml found in workspace."}
179174

180-
# since we are using worktrees, optimization diffs are generated with respect to the root of the repo, also the args.project_root is set to the root of the repo when creating a worktree
175+
# since we are using worktrees, optimization diffs are generated with respect to the root of the repo.
181176
root = str(git_root_dir())
182177

183178
if getattr(params, "skip_validation", False):

codeflash/optimization/optimizer.py

Lines changed: 41 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
from codeflash.code_utils import env_utils
1616
from codeflash.code_utils.code_utils import cleanup_paths, get_run_tmp_file
1717
from codeflash.code_utils.env_utils import get_pr_number, is_pr_draft
18-
from codeflash.code_utils.git_utils import check_running_in_git_repo
18+
from codeflash.code_utils.git_utils import check_running_in_git_repo, git_root_dir
1919
from codeflash.code_utils.git_worktree_utils import (
2020
create_detached_worktree,
2121
create_diff_patch_from_worktree,
@@ -447,35 +447,52 @@ def worktree_mode(self) -> None:
447447
Path(self.args.tests_root).mkdir(parents=True, exist_ok=True)
448448

449449
def mutate_args_for_worktree_mode(self, worktree_dir: Path) -> None:
450-
saved_args = copy.deepcopy(self.args)
451-
saved_test_cfg = copy.deepcopy(self.test_cfg)
452-
self.original_args_and_test_cfg = (saved_args, saved_test_cfg)
453-
454-
project_root = self.args.project_root
455-
module_root = self.args.module_root
456-
relative_module_root = module_root.relative_to(project_root)
457-
relative_optimized_file = self.args.file.relative_to(project_root) if self.args.file else None
458-
relative_tests_root = self.test_cfg.tests_root.relative_to(project_root)
459-
relative_benchmarks_root = (
460-
self.args.benchmarks_root.relative_to(project_root) if self.args.benchmarks_root else None
450+
original_args = copy.deepcopy(self.args)
451+
original_test_cfg = copy.deepcopy(self.test_cfg)
452+
self.original_args_and_test_cfg = (original_args, original_test_cfg)
453+
454+
original_module_root = original_args.module_root
455+
original_git_root = git_root_dir().as_posix()
456+
457+
# mutate project_root
458+
relative_project_root = original_args.project_root.relative_to(original_git_root).as_posix()
459+
# this will be the same as the original project root but in the worktree
460+
new_project_root = worktree_dir / relative_project_root
461+
self.args.project_root = new_project_root
462+
self.test_cfg.project_root_path = new_project_root
463+
464+
# mutate module_root
465+
relative_module_root = original_module_root.relative_to(original_git_root).as_posix()
466+
self.args.module_root = worktree_dir / relative_module_root
467+
468+
# mute target file
469+
relative_optimized_file = (
470+
original_args.file.relative_to(original_git_root).as_posix() if original_args.file else None
461471
)
472+
if relative_optimized_file is not None:
473+
self.args.file = worktree_dir / relative_optimized_file
462474

463-
self.args.module_root = worktree_dir / relative_module_root
464-
self.args.project_root = worktree_dir
465-
self.args.test_project_root = worktree_dir
466-
self.args.tests_root = worktree_dir / relative_tests_root
467-
if relative_benchmarks_root:
468-
self.args.benchmarks_root = worktree_dir / relative_benchmarks_root
475+
# mutate tests root
476+
relative_tests_root = original_test_cfg.tests_root.relative_to(original_git_root).as_posix()
477+
new_tests_root = worktree_dir / relative_tests_root
478+
self.args.tests_root = new_tests_root
479+
self.test_cfg.tests_root = new_tests_root
469480

470-
self.test_cfg.project_root_path = worktree_dir
471-
self.test_cfg.tests_project_rootdir = worktree_dir
472-
self.test_cfg.tests_root = worktree_dir / relative_tests_root
481+
# mutate tests project root
482+
relative_tests_project_root = original_args.test_project_root.relative_to(original_git_root).as_posix()
483+
self.args.test_project_root = worktree_dir / relative_tests_project_root
484+
self.test_cfg.tests_project_rootdir = worktree_dir / relative_tests_project_root
485+
486+
# mutate benchmarks root
487+
relative_benchmarks_root = (
488+
original_args.benchmarks_root.relative_to(original_git_root).as_posix()
489+
if original_args.benchmarks_root
490+
else None
491+
)
473492
if relative_benchmarks_root:
493+
self.args.benchmarks_root = worktree_dir / relative_benchmarks_root
474494
self.test_cfg.benchmark_tests_root = worktree_dir / relative_benchmarks_root
475495

476-
if relative_optimized_file is not None:
477-
self.args.file = worktree_dir / relative_optimized_file
478-
479496

480497
def run_with_args(args: Namespace) -> None:
481498
optimizer = None

0 commit comments

Comments
 (0)