Skip to content

Commit 8906a72

Browse files
Merge pull request #363 from codeflash-ai/part-1-windows-fixes
path normalization and tempdir fixes for windows
2 parents 565d65b + 91870c0 commit 8906a72

36 files changed, with 1313 additions and 1045 deletions

.github/workflows/unit-tests.yaml

Lines changed: 0 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -24,7 +24,6 @@ jobs:
2424
uses: astral-sh/setup-uv@v5
2525
with:
2626
python-version: ${{ matrix.python-version }}
27-
version: "0.5.30"
2827

2928
- name: install dependencies
3029
run: uv sync
Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
name: windows-unit-tests
2+
3+
on:
4+
push:
5+
branches: [main]
6+
pull_request:
7+
workflow_dispatch:
8+
9+
jobs:
10+
windows-unit-tests:
11+
continue-on-error: true
12+
runs-on: windows-latest
13+
env:
14+
PYTHONIOENCODING: utf-8
15+
steps:
16+
- uses: actions/checkout@v4
17+
with:
18+
fetch-depth: 0
19+
token: ${{ secrets.GITHUB_TOKEN }}
20+
21+
- name: Install uv
22+
uses: astral-sh/setup-uv@v5
23+
with:
24+
python-version: "3.13"
25+
26+
- name: install dependencies
27+
run: uv sync
28+
29+
- name: Unit tests
30+
run: uv run pytest tests/

codeflash/benchmarking/codeflash_trace.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import sqlite3
55
import threading
66
import time
7+
from pathlib import Path
78
from typing import Any, Callable
89

910
from codeflash.picklepatch.pickle_patcher import PicklePatcher
@@ -143,12 +144,13 @@ def wrapper(*args, **kwargs) -> Any: # noqa: ANN002, ANN003, ANN401
143144
print("Pickle limit reached")
144145
self._thread_local.active_functions.remove(func_id)
145146
overhead_time = time.thread_time_ns() - end_time
147+
normalized_file_path = Path(func.__code__.co_filename).as_posix()
146148
self.function_calls_data.append(
147149
(
148150
func.__name__,
149151
class_name,
150152
func.__module__,
151-
func.__code__.co_filename,
153+
normalized_file_path,
152154
benchmark_function_name,
153155
benchmark_module_path,
154156
benchmark_line_number,
@@ -169,12 +171,13 @@ def wrapper(*args, **kwargs) -> Any: # noqa: ANN002, ANN003, ANN401
169171
# Add to the list of function calls without pickled args. Used for timing info only
170172
self._thread_local.active_functions.remove(func_id)
171173
overhead_time = time.thread_time_ns() - end_time
174+
normalized_file_path = Path(func.__code__.co_filename).as_posix()
172175
self.function_calls_data.append(
173176
(
174177
func.__name__,
175178
class_name,
176179
func.__module__,
177-
func.__code__.co_filename,
180+
normalized_file_path,
178181
benchmark_function_name,
179182
benchmark_module_path,
180183
benchmark_line_number,
@@ -192,12 +195,13 @@ def wrapper(*args, **kwargs) -> Any: # noqa: ANN002, ANN003, ANN401
192195
# Add to the list of function calls with pickled args, to be used for replay tests
193196
self._thread_local.active_functions.remove(func_id)
194197
overhead_time = time.thread_time_ns() - end_time
198+
normalized_file_path = Path(func.__code__.co_filename).as_posix()
195199
self.function_calls_data.append(
196200
(
197201
func.__name__,
198202
class_name,
199203
func.__module__,
200-
func.__code__.co_filename,
204+
normalized_file_path,
201205
benchmark_function_name,
202206
benchmark_module_path,
203207
benchmark_line_number,

codeflash/benchmarking/replay_test.py

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -30,19 +30,24 @@ def get_next_arg_and_return(
3030
cur = db.cursor()
3131
limit = num_to_get
3232

33+
normalized_file_path = Path(file_path).as_posix()
34+
3335
if class_name is not None:
3436
cursor = cur.execute(
3537
"SELECT * FROM benchmark_function_timings WHERE benchmark_function_name = ? AND function_name = ? AND file_path = ? AND class_name = ? LIMIT ?",
36-
(benchmark_function_name, function_name, file_path, class_name, limit),
38+
(benchmark_function_name, function_name, normalized_file_path, class_name, limit),
3739
)
3840
else:
3941
cursor = cur.execute(
4042
"SELECT * FROM benchmark_function_timings WHERE benchmark_function_name = ? AND function_name = ? AND file_path = ? AND class_name = '' LIMIT ?",
41-
(benchmark_function_name, function_name, file_path, limit),
43+
(benchmark_function_name, function_name, normalized_file_path, limit),
4244
)
4345

44-
while (val := cursor.fetchone()) is not None:
45-
yield val[9], val[10] # pickled_args, pickled_kwargs
46+
try:
47+
while (val := cursor.fetchone()) is not None:
48+
yield val[9], val[10] # pickled_args, pickled_kwargs
49+
finally:
50+
db.close()
4651

4752

4853
def get_function_alias(module: str, function_name: str) -> str:
@@ -166,7 +171,7 @@ def create_trace_replay_test_code(
166171
module_name = func.get("module_name")
167172
function_name = func.get("function_name")
168173
class_name = func.get("class_name")
169-
file_path = func.get("file_path")
174+
file_path = Path(func.get("file_path")).as_posix()
170175
benchmark_function_name = func.get("benchmark_function_name")
171176
function_properties = func.get("function_properties")
172177
if not class_name:

codeflash/code_utils/checkpoint.py

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22

33
import datetime
44
import json
5-
import sys
65
import time
76
import uuid
87
from pathlib import Path
@@ -11,13 +10,16 @@
1110
from rich.prompt import Confirm
1211

1312
from codeflash.cli_cmds.console import console
13+
from codeflash.code_utils.compat import codeflash_temp_dir
1414

1515
if TYPE_CHECKING:
1616
import argparse
1717

1818

1919
class CodeflashRunCheckpoint:
20-
def __init__(self, module_root: Path, checkpoint_dir: Path = Path("/tmp")) -> None: # noqa: S108
20+
def __init__(self, module_root: Path, checkpoint_dir: Path | None = None) -> None:
21+
if checkpoint_dir is None:
22+
checkpoint_dir = codeflash_temp_dir
2123
self.module_root = module_root
2224
self.checkpoint_dir = Path(checkpoint_dir)
2325
# Create a unique checkpoint file name
@@ -37,7 +39,7 @@ def _initialize_checkpoint_file(self) -> None:
3739
"last_updated": time.time(),
3840
}
3941

40-
with self.checkpoint_path.open("w") as f:
42+
with self.checkpoint_path.open("w", encoding="utf-8") as f:
4143
f.write(json.dumps(metadata) + "\n")
4244

4345
def add_function_to_checkpoint(
@@ -66,7 +68,7 @@ def add_function_to_checkpoint(
6668
**additional_info,
6769
}
6870

69-
with self.checkpoint_path.open("a") as f:
71+
with self.checkpoint_path.open("a", encoding="utf-8") as f:
7072
f.write(json.dumps(function_data) + "\n")
7173

7274
# Update the metadata last_updated timestamp
@@ -75,7 +77,7 @@ def add_function_to_checkpoint(
7577
def _update_metadata_timestamp(self) -> None:
7678
"""Update the last_updated timestamp in the metadata."""
7779
# Read the first line (metadata)
78-
with self.checkpoint_path.open() as f:
80+
with self.checkpoint_path.open(encoding="utf-8") as f:
7981
metadata = json.loads(f.readline())
8082
rest_content = f.read()
8183

@@ -84,7 +86,7 @@ def _update_metadata_timestamp(self) -> None:
8486

8587
# Write all lines to a temporary file
8688

87-
with self.checkpoint_path.open("w") as f:
89+
with self.checkpoint_path.open("w", encoding="utf-8") as f:
8890
f.write(json.dumps(metadata) + "\n")
8991
f.write(rest_content)
9092

@@ -94,7 +96,7 @@ def cleanup(self) -> None:
9496
self.checkpoint_path.unlink(missing_ok=True)
9597

9698
for file in self.checkpoint_dir.glob("codeflash_checkpoint_*.jsonl"):
97-
with file.open() as f:
99+
with file.open(encoding="utf-8") as f:
98100
# Skip the first line (metadata)
99101
first_line = next(f)
100102
metadata = json.loads(first_line)
@@ -116,7 +118,7 @@ def get_all_historical_functions(module_root: Path, checkpoint_dir: Path) -> dic
116118
to_delete = []
117119

118120
for file in checkpoint_dir.glob("codeflash_checkpoint_*.jsonl"):
119-
with file.open() as f:
121+
with file.open(encoding="utf-8") as f:
120122
# Skip the first line (metadata)
121123
first_line = next(f)
122124
metadata = json.loads(first_line)
@@ -139,8 +141,8 @@ def get_all_historical_functions(module_root: Path, checkpoint_dir: Path) -> dic
139141

140142
def ask_should_use_checkpoint_get_functions(args: argparse.Namespace) -> Optional[dict[str, dict[str, str]]]:
141143
previous_checkpoint_functions = None
142-
if args.all and (sys.platform == "linux" or sys.platform == "darwin") and Path("/tmp").is_dir(): # noqa: S108 #TODO: use the temp dir from codeutils-compat.py
143-
previous_checkpoint_functions = get_all_historical_functions(args.module_root, Path("/tmp")) # noqa: S108
144+
if args.all and codeflash_temp_dir.is_dir():
145+
previous_checkpoint_functions = get_all_historical_functions(args.module_root, codeflash_temp_dir)
144146
if previous_checkpoint_functions and Confirm.ask(
145147
"Previous Checkpoint detected from an incomplete optimization run, shall I continue the optimization from that point?",
146148
default=True,

codeflash/code_utils/code_utils.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -180,7 +180,7 @@ def module_name_from_file_path(file_path: Path, project_root_path: Path, *, trav
180180
parent = file_path.parent
181181
while parent not in (project_root_path, parent.parent):
182182
try:
183-
relative_path = file_path.relative_to(parent)
183+
relative_path = file_path.resolve().relative_to(parent.resolve())
184184
return relative_path.with_suffix("").as_posix().replace("/", ".")
185185
except ValueError:
186186
parent = parent.parent
@@ -245,8 +245,9 @@ def get_run_tmp_file(file_path: Path) -> Path:
245245

246246

247247
def path_belongs_to_site_packages(file_path: Path) -> bool:
248-
site_packages = [Path(p) for p in site.getsitepackages()]
249-
return any(file_path.resolve().is_relative_to(site_package_path) for site_package_path in site_packages)
248+
file_path_resolved = file_path.resolve()
249+
site_packages = [Path(p).resolve() for p in site.getsitepackages()]
250+
return any(file_path_resolved.is_relative_to(site_package_path) for site_package_path in site_packages)
250251

251252

252253
def is_class_defined_in_file(class_name: str, file_path: Path) -> bool:

codeflash/code_utils/coverage_utils.py

Lines changed: 16 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -44,17 +44,25 @@ def build_fully_qualified_name(function_name: str, code_context: CodeOptimizatio
4444
def generate_candidates(source_code_path: Path) -> set[str]:
4545
"""Generate all the possible candidates for coverage data based on the source code path."""
4646
candidates = set()
47-
candidates.add(source_code_path.name)
48-
current_path = source_code_path.parent
49-
50-
last_added = source_code_path.name
51-
while current_path != current_path.parent:
52-
candidate_path = str(Path(current_path.name) / last_added)
47+
# Add the filename as a candidate
48+
name = source_code_path.name
49+
candidates.add(name)
50+
51+
# Precompute parts for efficient candidate path construction
52+
parts = source_code_path.parts
53+
n = len(parts)
54+
55+
# Walk up the directory structure without creating Path objects or repeatedly converting to posix
56+
last_added = name
57+
# Start from the last parent and move up to the root, exclusive (skip the root itself)
58+
for i in range(n - 2, 0, -1):
59+
# Combine the ith part with the accumulated path (last_added)
60+
candidate_path = f"{parts[i]}/{last_added}"
5361
candidates.add(candidate_path)
5462
last_added = candidate_path
55-
current_path = current_path.parent
5663

57-
candidates.add(str(source_code_path))
64+
# Add the absolute posix path as a candidate
65+
candidates.add(source_code_path.as_posix())
5866
return candidates
5967

6068

codeflash/code_utils/env_utils.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -18,18 +18,17 @@ def check_formatter_installed(formatter_cmds: list[str], exit_on_failure: bool =
1818
if formatter_cmds[0] == "disabled":
1919
return return_code
2020
tmp_code = """print("hello world")"""
21-
with tempfile.NamedTemporaryFile(mode="w", encoding="utf-8", suffix=".py") as f:
22-
f.write(tmp_code)
23-
f.flush()
24-
tmp_file = Path(f.name)
21+
with tempfile.TemporaryDirectory() as tmpdir:
22+
tmp_file = Path(tmpdir) / "test_codeflash_formatter.py"
23+
tmp_file.write_text(tmp_code, encoding="utf-8")
2524
try:
2625
format_code(formatter_cmds, tmp_file, print_status=False, exit_on_failure=exit_on_failure)
2726
except Exception:
2827
exit_with_message(
2928
"⚠️ Codeflash requires a code formatter to be installed in your environment, but none was found. Please install a supported formatter, verify the formatter-cmds in your codeflash pyproject.toml config and try again.",
3029
error_on_exit=True,
3130
)
32-
return return_code
31+
return return_code
3332

3433

3534
@lru_cache(maxsize=1)
@@ -121,7 +120,7 @@ def get_cached_gh_event_data() -> dict[str, Any]:
121120
event_path = os.getenv("GITHUB_EVENT_PATH")
122121
if not event_path:
123122
return {}
124-
with Path(event_path).open() as f:
123+
with Path(event_path).open(encoding="utf-8") as f:
125124
return json.load(f) # type: ignore # noqa
126125

127126

codeflash/code_utils/instrument_existing_tests.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from __future__ import annotations
22

33
import ast
4+
import platform
45
from pathlib import Path
56
from typing import TYPE_CHECKING
67

@@ -143,7 +144,10 @@ def visit_ClassDef(self, node: ast.ClassDef) -> ast.ClassDef:
143144
def visit_FunctionDef(self, node: ast.FunctionDef, test_class_name: str | None = None) -> ast.FunctionDef:
144145
if node.name.startswith("test_"):
145146
did_update = False
146-
if self.test_framework == "unittest":
147+
if self.test_framework == "unittest" and platform.system() != "Windows":
148+
# Only add timeout decorator on non-Windows platforms
149+
# Windows doesn't support SIGALRM signal required by timeout_decorator
150+
147151
node.decorator_list.append(
148152
ast.Call(
149153
func=ast.Name(id="timeout_decorator.timeout", ctx=ast.Load()),
@@ -220,7 +224,9 @@ def visit_FunctionDef(self, node: ast.FunctionDef, test_class_name: str | None =
220224
args=[
221225
ast.JoinedStr(
222226
values=[
223-
ast.Constant(value=f"{get_run_tmp_file(Path('test_return_values_'))}"),
227+
ast.Constant(
228+
value=f"{get_run_tmp_file(Path('test_return_values_')).as_posix()}"
229+
),
224230
ast.FormattedValue(
225231
value=ast.Name(id="codeflash_iteration", ctx=ast.Load()),
226232
conversion=-1,
@@ -588,7 +594,7 @@ def inject_profiling_into_existing_test(
588594
new_imports.extend(
589595
[ast.Import(names=[ast.alias(name="sqlite3")]), ast.Import(names=[ast.alias(name="dill", asname="pickle")])]
590596
)
591-
if test_framework == "unittest":
597+
if test_framework == "unittest" and platform.system() != "Windows":
592598
new_imports.append(ast.Import(names=[ast.alias(name="timeout_decorator")]))
593599
additional_functions = [create_wrapper_function(mode)]
594600

codeflash/code_utils/line_profile_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -219,6 +219,6 @@ def add_decorator_imports(function_to_optimize: FunctionToOptimize, code_context
219219
file.write(modified_code)
220220
# Adding profile.enable line for changing the savepath of the data, do this only for the main file and not the helper files
221221
file_contents = function_to_optimize.file_path.read_text("utf-8")
222-
modified_code = add_profile_enable(file_contents, str(line_profile_output_file))
222+
modified_code = add_profile_enable(file_contents, line_profile_output_file.as_posix())
223223
function_to_optimize.file_path.write_text(modified_code, "utf-8")
224224
return line_profile_output_file

0 commit comments

Comments
 (0)