Skip to content

Commit 0dc325a

Browse files
committed
sets instead of lists
1 parent d4788b9 commit 0dc325a

File tree

2 files changed

+13
-17
lines changed

2 files changed

+13
-17
lines changed

codeflash/tracer.py

Lines changed: 3 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,6 @@
3333

3434

3535
def main(args: Namespace | None = None) -> ArgumentParser:
36-
start = time.time()
3736
parser = ArgumentParser(allow_abbrev=False)
3837
parser.add_argument("-o", "--outfile", dest="outfile", help="Save trace to <outfile>", default="codeflash.trace")
3938
parser.add_argument("--only-functions", help="Trace only these functions", nargs="+", default=None)
@@ -106,25 +105,22 @@ def main(args: Namespace | None = None) -> ArgumentParser:
106105
replay_test_paths = []
107106
if parsed_args.module and unknown_args[0] == "pytest":
108107
pytest_splits, test_paths = pytest_split(unknown_args[1:])
109-
print(pytest_splits)
110108

111109
if len(pytest_splits) > 1:
112110
processes = []
113111
test_paths_set = set(test_paths)
114112
result_pickle_file_paths = []
115113
for i, test_split in enumerate(pytest_splits, start=1):
116-
result_pickle_file_path = get_run_tmp_file(f"tracer_results_file_{i}.pkl")
114+
result_pickle_file_path = get_run_tmp_file(Path(f"tracer_results_file_{i}.pkl"))
117115
result_pickle_file_paths.append(result_pickle_file_path)
118116
args_dict["result_pickle_file_path"] = str(result_pickle_file_path)
119117
outpath = parsed_args.outfile
120118
outpath = outpath.parent / f"{outpath.stem}_{i}{outpath.suffix}"
121119
args_dict["output"] = str(outpath)
122-
added_paths = False
123120
updated_sys_argv = []
124121
for elem in sys.argv:
125122
if elem in test_paths_set:
126-
if not added_paths:
127-
updated_sys_argv.extend(test_split)
123+
updated_sys_argv.extend(test_split)
128124
else:
129125
updated_sys_argv.append(elem)
130126
args_dict["command"] = " ".join(updated_sys_argv)
@@ -152,7 +148,7 @@ def main(args: Namespace | None = None) -> ArgumentParser:
152148
finally:
153149
result_pickle_file_path.unlink(missing_ok=True)
154150
else:
155-
result_pickle_file_path = get_run_tmp_file("tracer_results_file.pkl")
151+
result_pickle_file_path = get_run_tmp_file(Path("tracer_results_file.pkl"))
156152
args_dict["result_pickle_file_path"] = str(result_pickle_file_path)
157153
args_dict["output"] = str(parsed_args.outfile)
158154
args_dict["command"] = " ".join(sys.argv)
@@ -176,7 +172,6 @@ def main(args: Namespace | None = None) -> ArgumentParser:
176172
sys.exit(1)
177173
finally:
178174
result_pickle_file_path.unlink(missing_ok=True)
179-
print(f"Took {time.time() - start}")
180175
if not parsed_args.trace_only and replay_test_paths:
181176
from codeflash.cli_cmds.cli import parse_args, process_pyproject_config
182177
from codeflash.cli_cmds.cmd_init import CODEFLASH_LOGO
@@ -185,7 +180,6 @@ def main(args: Namespace | None = None) -> ArgumentParser:
185180
from codeflash.telemetry.sentry import init_sentry
186181

187182
sys.argv = ["codeflash", "--replay-test", *replay_test_paths]
188-
print(sys.argv)
189183
args = parse_args()
190184
paneled_text(
191185
CODEFLASH_LOGO,

codeflash/tracing/pytest_parallelization.py

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import os
44
from math import ceil
55
from pathlib import Path
6-
6+
from random import shuffle
77

88
def pytest_split(
99
arguments: list[str], num_splits: int | None = None
@@ -32,7 +32,7 @@ def pytest_split(
3232

3333
except ImportError:
3434
return None, None
35-
test_files = []
35+
test_files = set()
3636

3737
# Collect test files: recurse into directories for test_*.py and *_test.py, and take explicit file paths as-is
3838
for test_path in test_paths:
@@ -41,12 +41,10 @@ def pytest_split(
4141
return None, None
4242
if _test_path.is_dir():
4343
# Find all test files matching the patterns test_*.py and *_test.py
44-
test_files.extend(map(str, _test_path.rglob("test_*.py")))
44+
test_files.update(map(str, _test_path.rglob("test_*.py")))
45+
test_files.update(map(str, _test_path.rglob("*_test.py")))
4546
elif _test_path.is_file():
46-
test_files.append(str(_test_path))
47-
48-
# Sort files for consistent ordering
49-
test_files.sort()
47+
test_files.add(str(_test_path))
5048

5149
if not test_files:
5250
return [[]], None
@@ -55,11 +53,15 @@ def pytest_split(
5553
if num_splits is None:
5654
num_splits = os.cpu_count() or 4
5755

56+
# Randomize the file order to increase the chances of all splits being balanced
57+
test_files = list(test_files)
58+
shuffle(test_files)
59+
5860
# Ensure each split has at least 4 test files
5961
# If we have fewer test files than 4 * num_splits, reduce num_splits
6062
max_possible_splits = len(test_files) // 4
6163
if max_possible_splits == 0:
62-
return [test_files], test_paths
64+
return test_files, test_paths
6365

6466
num_splits = min(num_splits, max_possible_splits)
6567

0 commit comments

Comments
 (0)