Skip to content

Commit bc898cf

Browse files
authored
Merge branch 'main' into init/install-vscode-extension
2 parents dc22b4a + 3dcf7a3 commit bc898cf

19 files changed

+1644
-943
lines changed

codeflash/code_utils/code_extractor.py

Lines changed: 32 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -335,12 +335,12 @@ def leave_Module(self, original_node: cst.Module, updated_node: cst.Module) -> c
335335
return updated_node
336336

337337

338-
def extract_global_statements(source_code: str) -> list[cst.SimpleStatementLine]:
338+
def extract_global_statements(source_code: str) -> tuple[cst.Module, list[cst.SimpleStatementLine]]:
339339
"""Extract global statements from source code."""
340340
module = cst.parse_module(source_code)
341341
collector = GlobalStatementCollector()
342342
module.visit(collector)
343-
return collector.global_statements
343+
return module, collector.global_statements
344344

345345

346346
def find_last_import_line(target_code: str) -> int:
@@ -373,30 +373,41 @@ def delete___future___aliased_imports(module_code: str) -> str:
373373

374374

375375
def add_global_assignments(src_module_code: str, dst_module_code: str) -> str:
376-
non_assignment_global_statements = extract_global_statements(src_module_code)
376+
src_module, new_added_global_statements = extract_global_statements(src_module_code)
377+
dst_module, existing_global_statements = extract_global_statements(dst_module_code)
377378

378-
# Find the last import line in target
379-
last_import_line = find_last_import_line(dst_module_code)
380-
381-
# Parse the target code
382-
target_module = cst.parse_module(dst_module_code)
383-
384-
# Create transformer to insert non_assignment_global_statements
385-
transformer = ImportInserter(non_assignment_global_statements, last_import_line)
386-
#
387-
# # Apply transformation
388-
modified_module = target_module.visit(transformer)
389-
dst_module_code = modified_module.code
390-
391-
# Parse the code
392-
original_module = cst.parse_module(dst_module_code)
393-
new_module = cst.parse_module(src_module_code)
379+
unique_global_statements = []
380+
for stmt in new_added_global_statements:
381+
if any(
382+
stmt is existing_stmt or stmt.deep_equals(existing_stmt) for existing_stmt in existing_global_statements
383+
):
384+
continue
385+
unique_global_statements.append(stmt)
386+
387+
mod_dst_code = dst_module_code
388+
# Insert unique global statements if any
389+
if unique_global_statements:
390+
last_import_line = find_last_import_line(dst_module_code)
391+
# Reuse already-parsed dst_module
392+
transformer = ImportInserter(unique_global_statements, last_import_line)
393+
# Use visit inplace, don't parse again
394+
modified_module = dst_module.visit(transformer)
395+
mod_dst_code = modified_module.code
396+
# Parse the code after insertion
397+
original_module = cst.parse_module(mod_dst_code)
398+
else:
399+
# No new statements to insert, reuse already-parsed dst_module
400+
original_module = dst_module
394401

402+
# Parse the src_module_code once only (already done above: src_module)
395403
# Collect assignments from the new file
396404
new_collector = GlobalAssignmentCollector()
397-
new_module.visit(new_collector)
405+
src_module.visit(new_collector)
406+
# Only create transformer if there are assignments to insert/transform
407+
if not new_collector.assignments: # nothing to transform
408+
return mod_dst_code
398409

399-
# Transform the original file
410+
# Transform the original destination module
400411
transformer = GlobalAssignmentTransformer(new_collector.assignments, new_collector.assignment_order)
401412
transformed_module = original_module.visit(transformer)
402413

codeflash/code_utils/code_replacer.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -412,11 +412,17 @@ def replace_function_definitions_in_module(
412412
module_abspath: Path,
413413
preexisting_objects: set[tuple[str, tuple[FunctionParent, ...]]],
414414
project_root_path: Path,
415+
should_add_global_assignments: bool = True, # noqa: FBT001, FBT002
415416
) -> bool:
416417
source_code: str = module_abspath.read_text(encoding="utf8")
417418
code_to_apply = get_optimized_code_for_module(module_abspath.relative_to(project_root_path), optimized_code)
419+
418420
new_code: str = replace_functions_and_add_imports(
419-
add_global_assignments(code_to_apply, source_code),
421+
# adding the global assignments before replacing the code, not after
422+
# because of an "edge case" where the optimized code introduced a new import and a global assignment using that import
423+
# and that import wasn't used before, so it was ignored when calling AddImportsVisitor.add_needed_import inside replace_functions_and_add_imports (because the global assignment wasn't added yet)
424+
# this was added at https://github.com/codeflash-ai/codeflash/pull/448
425+
add_global_assignments(code_to_apply, source_code) if should_add_global_assignments else source_code,
420426
function_names,
421427
code_to_apply,
422428
module_abspath,

codeflash/code_utils/git_utils.py

Lines changed: 8 additions & 84 deletions
Original file line numberDiff line numberDiff line change
@@ -16,19 +16,24 @@
1616
from unidiff import PatchSet
1717

1818
from codeflash.cli_cmds.console import logger
19-
from codeflash.code_utils.compat import codeflash_cache_dir
2019
from codeflash.code_utils.config_consts import N_CANDIDATES
2120

2221
if TYPE_CHECKING:
2322
from git import Repo
2423

2524

26-
def get_git_diff(repo_directory: Path | None = None, *, uncommitted_changes: bool = False) -> dict[str, list[int]]:
25+
def get_git_diff(
26+
repo_directory: Path | None = None, *, only_this_commit: Optional[str] = None, uncommitted_changes: bool = False
27+
) -> dict[str, list[int]]:
2728
if repo_directory is None:
2829
repo_directory = Path.cwd()
2930
repository = git.Repo(repo_directory, search_parent_directories=True)
3031
commit = repository.head.commit
31-
if uncommitted_changes:
32+
if only_this_commit:
33+
uni_diff_text = repository.git.diff(
34+
only_this_commit + "^1", only_this_commit, ignore_blank_lines=True, ignore_space_at_eol=True
35+
)
36+
elif uncommitted_changes:
3237
uni_diff_text = repository.git.diff(None, "HEAD", ignore_blank_lines=True, ignore_space_at_eol=True)
3338
else:
3439
uni_diff_text = repository.git.diff(
@@ -193,84 +198,3 @@ def get_last_commit_author_if_pr_exists(repo: Repo | None = None) -> str | None:
193198
return None
194199
else:
195200
return last_commit.author.name
196-
197-
198-
worktree_dirs = codeflash_cache_dir / "worktrees"
199-
patches_dir = codeflash_cache_dir / "patches"
200-
201-
202-
def create_worktree_snapshot_commit(worktree_dir: Path, commit_message: str) -> None:
203-
repository = git.Repo(worktree_dir, search_parent_directories=True)
204-
repository.git.add(".")
205-
repository.git.commit("-m", commit_message, "--no-verify")
206-
207-
208-
def create_detached_worktree(module_root: Path) -> Optional[Path]:
209-
if not check_running_in_git_repo(module_root):
210-
logger.warning("Module is not in a git repository. Skipping worktree creation.")
211-
return None
212-
git_root = git_root_dir()
213-
current_time_str = time.strftime("%Y%m%d-%H%M%S")
214-
worktree_dir = worktree_dirs / f"{git_root.name}-{current_time_str}"
215-
216-
repository = git.Repo(git_root, search_parent_directories=True)
217-
218-
repository.git.worktree("add", "-d", str(worktree_dir))
219-
220-
# Get uncommitted diff from the original repo
221-
repository.git.add("-N", ".") # add the index for untracked files to be included in the diff
222-
exclude_binary_files = [":!*.pyc", ":!*.pyo", ":!*.pyd", ":!*.so", ":!*.dll", ":!*.whl", ":!*.egg", ":!*.egg-info", ":!*.pyz", ":!*.pkl", ":!*.pickle", ":!*.joblib", ":!*.npy", ":!*.npz", ":!*.h5", ":!*.hdf5", ":!*.pth", ":!*.pt", ":!*.pb", ":!*.onnx", ":!*.db", ":!*.sqlite", ":!*.sqlite3", ":!*.feather", ":!*.parquet", ":!*.jpg", ":!*.jpeg", ":!*.png", ":!*.gif", ":!*.bmp", ":!*.tiff", ":!*.webp", ":!*.wav", ":!*.mp3", ":!*.ogg", ":!*.flac", ":!*.mp4", ":!*.avi", ":!*.mov", ":!*.mkv", ":!*.pdf", ":!*.doc", ":!*.docx", ":!*.xls", ":!*.xlsx", ":!*.ppt", ":!*.pptx", ":!*.zip", ":!*.rar", ":!*.tar", ":!*.tar.gz", ":!*.tgz", ":!*.bz2", ":!*.xz"] # fmt: off
223-
uni_diff_text = repository.git.diff(
224-
None, "HEAD", "--", *exclude_binary_files, ignore_blank_lines=True, ignore_space_at_eol=True
225-
)
226-
227-
if not uni_diff_text.strip():
228-
logger.info("No uncommitted changes to copy to worktree.")
229-
return worktree_dir
230-
231-
# Write the diff to a temporary file
232-
with tempfile.NamedTemporaryFile(mode="w", suffix=".codeflash.patch", delete=False) as tmp_patch_file:
233-
tmp_patch_file.write(uni_diff_text + "\n") # the new line here is a must otherwise the last hunk won't be valid
234-
tmp_patch_file.flush()
235-
236-
patch_path = Path(tmp_patch_file.name).resolve()
237-
238-
# Apply the patch inside the worktree
239-
try:
240-
subprocess.run(
241-
["git", "apply", "--ignore-space-change", "--ignore-whitespace", "--whitespace=nowarn", patch_path],
242-
cwd=worktree_dir,
243-
check=True,
244-
)
245-
create_worktree_snapshot_commit(worktree_dir, "Initial Snapshot")
246-
except subprocess.CalledProcessError as e:
247-
logger.error(f"Failed to apply patch to worktree: {e}")
248-
249-
return worktree_dir
250-
251-
252-
def remove_worktree(worktree_dir: Path) -> None:
253-
try:
254-
repository = git.Repo(worktree_dir, search_parent_directories=True)
255-
repository.git.worktree("remove", "--force", worktree_dir)
256-
except Exception:
257-
logger.exception(f"Failed to remove worktree: {worktree_dir}")
258-
259-
260-
def create_diff_patch_from_worktree(worktree_dir: Path, files: list[str], fto_name: str) -> Path:
261-
repository = git.Repo(worktree_dir, search_parent_directories=True)
262-
uni_diff_text = repository.git.diff(None, "HEAD", *files, ignore_blank_lines=True, ignore_space_at_eol=True)
263-
264-
if not uni_diff_text:
265-
logger.warning("No changes found in worktree.")
266-
return None
267-
268-
if not uni_diff_text.endswith("\n"):
269-
uni_diff_text += "\n"
270-
271-
# write to patches_dir
272-
patches_dir.mkdir(parents=True, exist_ok=True)
273-
patch_path = patches_dir / f"{worktree_dir.name}.{fto_name}.patch"
274-
with patch_path.open("w", encoding="utf8") as f:
275-
f.write(uni_diff_text)
276-
return patch_path
Lines changed: 170 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,170 @@
1+
from __future__ import annotations
2+
3+
import json
4+
import subprocess
5+
import tempfile
6+
import time
7+
from functools import lru_cache
8+
from pathlib import Path
9+
from typing import TYPE_CHECKING, Optional
10+
11+
import git
12+
from filelock import FileLock
13+
14+
from codeflash.cli_cmds.console import logger
15+
from codeflash.code_utils.compat import codeflash_cache_dir
16+
from codeflash.code_utils.git_utils import check_running_in_git_repo, git_root_dir
17+
18+
if TYPE_CHECKING:
19+
from typing import Any
20+
21+
from git import Repo
22+
23+
24+
worktree_dirs = codeflash_cache_dir / "worktrees"
25+
patches_dir = codeflash_cache_dir / "patches"
26+
27+
if TYPE_CHECKING:
28+
from git import Repo
29+
30+
31+
@lru_cache(maxsize=1)
32+
def get_git_project_id() -> str:
33+
"""Return the first commit sha of the repo."""
34+
repo: Repo = git.Repo(search_parent_directories=True)
35+
root_commits = list(repo.iter_commits(rev="HEAD", max_parents=0))
36+
return root_commits[0].hexsha
37+
38+
39+
def create_worktree_snapshot_commit(worktree_dir: Path, commit_message: str) -> None:
40+
repository = git.Repo(worktree_dir, search_parent_directories=True)
41+
repository.git.add(".")
42+
repository.git.commit("-m", commit_message, "--no-verify")
43+
44+
45+
def create_detached_worktree(module_root: Path) -> Optional[Path]:
46+
if not check_running_in_git_repo(module_root):
47+
logger.warning("Module is not in a git repository. Skipping worktree creation.")
48+
return None
49+
git_root = git_root_dir()
50+
current_time_str = time.strftime("%Y%m%d-%H%M%S")
51+
worktree_dir = worktree_dirs / f"{git_root.name}-{current_time_str}"
52+
53+
repository = git.Repo(git_root, search_parent_directories=True)
54+
55+
repository.git.worktree("add", "-d", str(worktree_dir))
56+
57+
# Get uncommitted diff from the original repo
58+
repository.git.add("-N", ".") # add the index for untracked files to be included in the diff
59+
exclude_binary_files = [":!*.pyc", ":!*.pyo", ":!*.pyd", ":!*.so", ":!*.dll", ":!*.whl", ":!*.egg", ":!*.egg-info", ":!*.pyz", ":!*.pkl", ":!*.pickle", ":!*.joblib", ":!*.npy", ":!*.npz", ":!*.h5", ":!*.hdf5", ":!*.pth", ":!*.pt", ":!*.pb", ":!*.onnx", ":!*.db", ":!*.sqlite", ":!*.sqlite3", ":!*.feather", ":!*.parquet", ":!*.jpg", ":!*.jpeg", ":!*.png", ":!*.gif", ":!*.bmp", ":!*.tiff", ":!*.webp", ":!*.wav", ":!*.mp3", ":!*.ogg", ":!*.flac", ":!*.mp4", ":!*.avi", ":!*.mov", ":!*.mkv", ":!*.pdf", ":!*.doc", ":!*.docx", ":!*.xls", ":!*.xlsx", ":!*.ppt", ":!*.pptx", ":!*.zip", ":!*.rar", ":!*.tar", ":!*.tar.gz", ":!*.tgz", ":!*.bz2", ":!*.xz"] # fmt: off
60+
uni_diff_text = repository.git.diff(
61+
None, "HEAD", "--", *exclude_binary_files, ignore_blank_lines=True, ignore_space_at_eol=True
62+
)
63+
64+
if not uni_diff_text.strip():
65+
logger.info("No uncommitted changes to copy to worktree.")
66+
return worktree_dir
67+
68+
# Write the diff to a temporary file
69+
with tempfile.NamedTemporaryFile(mode="w", suffix=".codeflash.patch", delete=False) as tmp_patch_file:
70+
tmp_patch_file.write(uni_diff_text + "\n") # the new line here is a must otherwise the last hunk won't be valid
71+
tmp_patch_file.flush()
72+
73+
patch_path = Path(tmp_patch_file.name).resolve()
74+
75+
# Apply the patch inside the worktree
76+
try:
77+
subprocess.run(
78+
["git", "apply", "--ignore-space-change", "--ignore-whitespace", "--whitespace=nowarn", patch_path],
79+
cwd=worktree_dir,
80+
check=True,
81+
)
82+
create_worktree_snapshot_commit(worktree_dir, "Initial Snapshot")
83+
except subprocess.CalledProcessError as e:
84+
logger.error(f"Failed to apply patch to worktree: {e}")
85+
86+
return worktree_dir
87+
88+
89+
def remove_worktree(worktree_dir: Path) -> None:
90+
try:
91+
repository = git.Repo(worktree_dir, search_parent_directories=True)
92+
repository.git.worktree("remove", "--force", worktree_dir)
93+
except Exception:
94+
logger.exception(f"Failed to remove worktree: {worktree_dir}")
95+
96+
97+
@lru_cache(maxsize=1)
98+
def get_patches_dir_for_project() -> Path:
99+
project_id = get_git_project_id() or ""
100+
return Path(patches_dir / project_id)
101+
102+
103+
def get_patches_metadata() -> dict[str, Any]:
104+
project_patches_dir = get_patches_dir_for_project()
105+
meta_file = project_patches_dir / "metadata.json"
106+
if meta_file.exists():
107+
with meta_file.open("r", encoding="utf-8") as f:
108+
return json.load(f)
109+
return {"id": get_git_project_id() or "", "patches": []}
110+
111+
112+
def save_patches_metadata(patch_metadata: dict) -> dict:
113+
project_patches_dir = get_patches_dir_for_project()
114+
meta_file = project_patches_dir / "metadata.json"
115+
lock_file = project_patches_dir / "metadata.json.lock"
116+
117+
# we do not support multiple concurrent optimizations within the same process, but keep the file lock in case we decide to do so in the future.
118+
with FileLock(lock_file, timeout=10):
119+
metadata = get_patches_metadata()
120+
121+
patch_metadata["id"] = time.strftime("%Y%m%d-%H%M%S")
122+
metadata["patches"].append(patch_metadata)
123+
124+
meta_file.write_text(json.dumps(metadata, indent=2))
125+
126+
return patch_metadata
127+
128+
129+
def overwrite_patch_metadata(patches: list[dict]) -> bool:
130+
project_patches_dir = get_patches_dir_for_project()
131+
meta_file = project_patches_dir / "metadata.json"
132+
lock_file = project_patches_dir / "metadata.json.lock"
133+
134+
with FileLock(lock_file, timeout=10):
135+
metadata = get_patches_metadata()
136+
metadata["patches"] = patches
137+
meta_file.write_text(json.dumps(metadata, indent=2))
138+
return True
139+
140+
141+
def create_diff_patch_from_worktree(
142+
worktree_dir: Path,
143+
files: list[str],
144+
fto_name: Optional[str] = None,
145+
metadata_input: Optional[dict[str, Any]] = None,
146+
) -> dict[str, Any]:
147+
repository = git.Repo(worktree_dir, search_parent_directories=True)
148+
uni_diff_text = repository.git.diff(None, "HEAD", *files, ignore_blank_lines=True, ignore_space_at_eol=True)
149+
150+
if not uni_diff_text:
151+
logger.warning("No changes found in worktree.")
152+
return {}
153+
154+
if not uni_diff_text.endswith("\n"):
155+
uni_diff_text += "\n"
156+
157+
project_patches_dir = get_patches_dir_for_project()
158+
project_patches_dir.mkdir(parents=True, exist_ok=True)
159+
160+
final_function_name = fto_name or metadata_input.get("fto_name", "unknown")
161+
patch_path = project_patches_dir / f"{worktree_dir.name}.{final_function_name}.patch"
162+
with patch_path.open("w", encoding="utf8") as f:
163+
f.write(uni_diff_text)
164+
165+
final_metadata = {"patch_path": str(patch_path)}
166+
if metadata_input:
167+
final_metadata.update(metadata_input)
168+
final_metadata = save_patches_metadata(final_metadata)
169+
170+
return final_metadata

codeflash/code_utils/shell_utils.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,9 @@
1515
SHELL_RC_EXPORT_PATTERN = re.compile(r"^set CODEFLASH_API_KEY=(cf-.*)$", re.MULTILINE)
1616
SHELL_RC_EXPORT_PREFIX = "set CODEFLASH_API_KEY="
1717
else:
18-
SHELL_RC_EXPORT_PATTERN = re.compile(r'^(?!#)export CODEFLASH_API_KEY=[\'"]?(cf-[^\s"]+)[\'"]$', re.MULTILINE)
18+
SHELL_RC_EXPORT_PATTERN = re.compile(
19+
r'^(?!#)export CODEFLASH_API_KEY=(?:"|\')?(cf-[^\s"\']+)(?:"|\')?$', re.MULTILINE
20+
)
1921
SHELL_RC_EXPORT_PREFIX = "export CODEFLASH_API_KEY="
2022

2123

codeflash/context/unused_definition_remover.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -537,6 +537,7 @@ def revert_unused_helper_functions(
537537
module_abspath=file_path,
538538
preexisting_objects=set(), # Empty set since we're reverting
539539
project_root_path=project_root,
540+
should_add_global_assignments=False, # since we revert helper functions after applying the optimization, we know that the file already has global assignments added; otherwise they would be added twice.
540541
)
541542

542543
if reverted_code:

0 commit comments

Comments
 (0)