Skip to content

Commit 258b523

Browse files
Merge pull request #693 from codeflash-ai/parallel-pytest-tracing
Parallel pytest tracing
2 parents 922f714 + aab93ef commit 258b523

File tree

2 files changed

+172
-40
lines changed

2 files changed

+172
-40
lines changed

codeflash/tracer.py

Lines changed: 89 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@
1414
import json
1515
import pickle
1616
import subprocess
17+
import time
18+
1719
import sys
1820
from argparse import ArgumentParser
1921
from pathlib import Path
@@ -24,6 +26,7 @@
2426
from codeflash.code_utils.code_utils import get_run_tmp_file
2527
from codeflash.code_utils.compat import SAFE_SYS_EXECUTABLE
2628
from codeflash.code_utils.config_parser import parse_config_file
29+
from codeflash.tracing.pytest_parallelization import pytest_split
2730

2831
if TYPE_CHECKING:
2932
from argparse import Namespace
@@ -86,51 +89,97 @@ def main(args: Namespace | None = None) -> ArgumentParser:
8689
config, found_config_path = parse_config_file(parsed_args.codeflash_config)
8790
project_root = project_root_from_module_root(Path(config["module_root"]), found_config_path)
8891
if len(unknown_args) > 0:
92+
args_dict = {
93+
"functions": parsed_args.only_functions,
94+
"disable": False,
95+
"project_root": str(project_root),
96+
"max_function_count": parsed_args.max_function_count,
97+
"timeout": parsed_args.tracer_timeout,
98+
"progname": unknown_args[0],
99+
"config": config,
100+
"module": parsed_args.module,
101+
}
89102
try:
90-
result_pickle_file_path = get_run_tmp_file("tracer_results_file.pkl")
91-
args_dict = {
92-
"result_pickle_file_path": str(result_pickle_file_path),
93-
"output": str(parsed_args.outfile),
94-
"functions": parsed_args.only_functions,
95-
"disable": False,
96-
"project_root": str(project_root),
97-
"max_function_count": parsed_args.max_function_count,
98-
"timeout": parsed_args.tracer_timeout,
99-
"command": " ".join(sys.argv),
100-
"progname": unknown_args[0],
101-
"config": config,
102-
"module": parsed_args.module,
103-
}
104-
105-
subprocess.run(
106-
[
107-
SAFE_SYS_EXECUTABLE,
108-
Path(__file__).parent / "tracing" / "tracing_new_process.py",
109-
*sys.argv,
110-
json.dumps(args_dict),
111-
],
112-
cwd=Path.cwd(),
113-
check=False,
114-
)
115-
try:
116-
with result_pickle_file_path.open(mode="rb") as f:
117-
data = pickle.load(f)
118-
except Exception:
119-
console.print("❌ Failed to trace. Exiting...")
120-
sys.exit(1)
121-
finally:
122-
result_pickle_file_path.unlink(missing_ok=True)
123-
124-
replay_test_path = data["replay_test_file_path"]
125-
if not parsed_args.trace_only and replay_test_path is not None:
103+
pytest_splits = []
104+
test_paths = []
105+
replay_test_paths = []
106+
if parsed_args.module and unknown_args[0] == "pytest":
107+
pytest_splits, test_paths = pytest_split(unknown_args[1:])
108+
109+
if len(pytest_splits) > 1:
110+
processes = []
111+
test_paths_set = set(test_paths)
112+
result_pickle_file_paths = []
113+
for i, test_split in enumerate(pytest_splits, start=1):
114+
result_pickle_file_path = get_run_tmp_file(Path(f"tracer_results_file_{i}.pkl"))
115+
result_pickle_file_paths.append(result_pickle_file_path)
116+
args_dict["result_pickle_file_path"] = str(result_pickle_file_path)
117+
outpath = parsed_args.outfile
118+
outpath = outpath.parent / f"{outpath.stem}_{i}{outpath.suffix}"
119+
args_dict["output"] = str(outpath)
120+
updated_sys_argv = []
121+
for elem in sys.argv:
122+
if elem in test_paths_set:
123+
updated_sys_argv.extend(test_split)
124+
else:
125+
updated_sys_argv.append(elem)
126+
args_dict["command"] = " ".join(updated_sys_argv)
127+
processes.append(
128+
subprocess.Popen(
129+
[
130+
SAFE_SYS_EXECUTABLE,
131+
Path(__file__).parent / "tracing" / "tracing_new_process.py",
132+
*updated_sys_argv,
133+
json.dumps(args_dict),
134+
],
135+
cwd=Path.cwd(),
136+
)
137+
)
138+
for process in processes:
139+
process.wait()
140+
for result_pickle_file_path in result_pickle_file_paths:
141+
try:
142+
with result_pickle_file_path.open(mode="rb") as f:
143+
data = pickle.load(f)
144+
replay_test_paths.append(str(data["replay_test_file_path"]))
145+
except Exception:
146+
console.print("❌ Failed to trace. Exiting...")
147+
sys.exit(1)
148+
finally:
149+
result_pickle_file_path.unlink(missing_ok=True)
150+
else:
151+
result_pickle_file_path = get_run_tmp_file(Path("tracer_results_file.pkl"))
152+
args_dict["result_pickle_file_path"] = str(result_pickle_file_path)
153+
args_dict["output"] = str(parsed_args.outfile)
154+
args_dict["command"] = " ".join(sys.argv)
155+
156+
subprocess.run(
157+
[
158+
SAFE_SYS_EXECUTABLE,
159+
Path(__file__).parent / "tracing" / "tracing_new_process.py",
160+
*sys.argv,
161+
json.dumps(args_dict),
162+
],
163+
cwd=Path.cwd(),
164+
check=False,
165+
)
166+
try:
167+
with result_pickle_file_path.open(mode="rb") as f:
168+
data = pickle.load(f)
169+
replay_test_paths.append(str(data["replay_test_file_path"]))
170+
except Exception:
171+
console.print("❌ Failed to trace. Exiting...")
172+
sys.exit(1)
173+
finally:
174+
result_pickle_file_path.unlink(missing_ok=True)
175+
if not parsed_args.trace_only and replay_test_paths:
126176
from codeflash.cli_cmds.cli import parse_args, process_pyproject_config
127177
from codeflash.cli_cmds.cmd_init import CODEFLASH_LOGO
128178
from codeflash.cli_cmds.console import paneled_text
129179
from codeflash.telemetry import posthog_cf
130180
from codeflash.telemetry.sentry import init_sentry
131181

132-
sys.argv = ["codeflash", "--replay-test", str(replay_test_path)]
133-
182+
sys.argv = ["codeflash", "--replay-test", *replay_test_paths]
134183
args = parse_args()
135184
paneled_text(
136185
CODEFLASH_LOGO,
@@ -150,8 +199,8 @@ def main(args: Namespace | None = None) -> ArgumentParser:
150199
# Delete the trace file and the replay test file if they exist
151200
if outfile:
152201
outfile.unlink(missing_ok=True)
153-
if replay_test_path:
154-
replay_test_path.unlink(missing_ok=True)
202+
for replay_test_path in replay_test_paths:
203+
Path(replay_test_path).unlink(missing_ok=True)
155204

156205
except BrokenPipeError as exc:
157206
# Prevent "Exception ignored" during interpreter shutdown.
Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,83 @@
1+
from __future__ import annotations
2+
3+
import os
4+
from math import ceil
5+
from pathlib import Path
6+
from random import shuffle
7+
8+
def pytest_split(
9+
arguments: list[str], num_splits: int | None = None
10+
) -> tuple[list[list[str]] | None, list[str] | None]:
11+
"""Split pytest test files from a directory into N roughly equal groups for parallel execution.
12+
13+
Args:
14+
arguments: List of arguments passed to pytest
15+
test_directory: Path to directory containing test files
16+
num_splits: Number of groups to split tests into. If None, uses CPU count.
17+
18+
Returns:
19+
List of lists, where each inner list contains test file paths for one group.
20+
Returns single list with all tests if number of test files < CPU cores.
21+
22+
"""
23+
try:
24+
import pytest
25+
26+
parser = pytest.Parser()
27+
28+
pytest_args = parser.parse_known_args(arguments)
29+
test_paths = getattr(pytest_args, "file_or_dir", None)
30+
if not test_paths:
31+
return None, None
32+
33+
except ImportError:
34+
return None, None
35+
test_files = set()
36+
37+
# Find all test_*.py files recursively in the directory
38+
for test_path in test_paths:
39+
_test_path = Path(test_path)
40+
if not _test_path.exists():
41+
return None, None
42+
if _test_path.is_dir():
43+
# Find all test files matching the pattern test_*.py
44+
test_files.update(map(str, _test_path.rglob("test_*.py")))
45+
test_files.update(map(str, _test_path.rglob("*_test.py")))
46+
elif _test_path.is_file():
47+
test_files.add(str(_test_path))
48+
49+
if not test_files:
50+
return [[]], None
51+
52+
# Determine number of splits
53+
if num_splits is None:
54+
num_splits = os.cpu_count() or 4
55+
56+
#randomize to increase chances of all splits being balanced
57+
test_files = list(test_files)
58+
shuffle(test_files)
59+
60+
# Ensure each split has at least 4 test files
61+
# If we have fewer test files than 4 * num_splits, reduce num_splits
62+
max_possible_splits = len(test_files) // 4
63+
if max_possible_splits == 0:
64+
return test_files, test_paths
65+
66+
num_splits = min(num_splits, max_possible_splits)
67+
68+
# Calculate chunk size (round up to ensure all files are included)
69+
total_files = len(test_files)
70+
chunk_size = ceil(total_files / num_splits)
71+
72+
# Initialize result groups
73+
result_groups = [[] for _ in range(num_splits)]
74+
75+
# Distribute files across groups
76+
for i, test_file in enumerate(test_files):
77+
group_index = i // chunk_size
78+
# Ensure we don't exceed the number of groups (edge case handling)
79+
if group_index >= num_splits:
80+
group_index = num_splits - 1
81+
result_groups[group_index].append(test_file)
82+
83+
return result_groups, test_paths

0 commit comments

Comments
 (0)