Skip to content

Commit 8a47148

Browse files
committed
infra: stabilize MPI tests and prevent hangs
- Sync gtest random_seed + filter across ranks - Abort on worker failure; wrap RunAllTests with MPI_Abort on exceptions - Make PPC_TEST_TMPDIR per MPI rank - Pass env to mpiexec on Windows (-env), keep -x on *nix
1 parent b6fe4e6 commit 8a47148

File tree

1 file changed

+87
-50
lines changed

1 file changed

+87
-50
lines changed

scripts/run_tests.py

Lines changed: 87 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,16 @@ def __init__(self, verbose=False):
5555
self.mpi_exec = "mpirun"
5656
self.platform = platform.system()
5757

58+
# Detect MPI implementation to choose compatible flags
59+
self.mpi_env_mode = "unknown" # one of: openmpi, mpich, unknown
60+
self.mpi_np_flag = "-np"
61+
if self.platform == "Windows":
62+
# MSMPI uses -env and -n
63+
self.mpi_env_mode = "mpich"
64+
self.mpi_np_flag = "-n"
65+
else:
66+
self.mpi_env_mode, self.mpi_np_flag = self.__detect_mpi_impl()
67+
5868
@staticmethod
5969
def __get_project_path():
6070
script_path = Path(__file__).resolve() # Absolute path of the script
@@ -89,6 +99,81 @@ def __run_exec(self, command):
8999
if result.returncode != 0:
90100
raise Exception(f"Subprocess return {result.returncode}.")
91101

102+
def __detect_mpi_impl(self):
103+
"""Detect MPI implementation and return (env_mode, np_flag).
104+
env_mode: 'openmpi' -> use '-x VAR', 'mpich' -> use '-genvlist VAR1,VAR2', 'unknown' -> pass no env flags.
105+
np_flag: '-np' for OpenMPI/unknown, '-n' for MPICH-family.
106+
"""
107+
probes = (["--version"], ["-V"], ["-v"], ["--help"], ["-help"])
108+
out = ""
109+
for args in probes:
110+
try:
111+
proc = subprocess.run(
112+
[self.mpi_exec] + list(args),
113+
stdout=subprocess.PIPE,
114+
stderr=subprocess.STDOUT,
115+
text=True,
116+
)
117+
out = (proc.stdout or "").lower()
118+
if out:
119+
break
120+
except Exception:
121+
continue
122+
123+
if "open mpi" in out or "ompi" in out:
124+
return "openmpi", "-np"
125+
if (
126+
"hydra" in out
127+
or "mpich" in out
128+
or "intel(r) mpi" in out
129+
or "intel mpi" in out
130+
):
131+
return "mpich", "-n"
132+
return "unknown", "-np"
133+
134+
def __build_mpi_cmd(self, ppc_num_proc, additional_mpi_args):
135+
base = [self.mpi_exec] + shlex.split(additional_mpi_args)
136+
137+
if self.platform == "Windows":
138+
# MS-MPI style
139+
env_args = [
140+
"-env",
141+
"PPC_NUM_THREADS",
142+
self.__ppc_env["PPC_NUM_THREADS"],
143+
"-env",
144+
"OMP_NUM_THREADS",
145+
self.__ppc_env["OMP_NUM_THREADS"],
146+
]
147+
np_args = ["-n", ppc_num_proc]
148+
return base + env_args + np_args
149+
150+
# Non-Windows
151+
if self.mpi_env_mode == "openmpi":
152+
env_args = [
153+
"-x",
154+
"PPC_NUM_THREADS",
155+
"-x",
156+
"OMP_NUM_THREADS",
157+
]
158+
np_flag = "-np"
159+
elif self.mpi_env_mode == "mpich":
160+
# Explicitly set env variables for all ranks
161+
env_args = [
162+
"-env",
163+
"PPC_NUM_THREADS",
164+
self.__ppc_env["PPC_NUM_THREADS"],
165+
"-env",
166+
"OMP_NUM_THREADS",
167+
self.__ppc_env["OMP_NUM_THREADS"],
168+
]
169+
np_flag = "-n"
170+
else:
171+
# Unknown MPI flavor: rely on environment inheritance and default to -np
172+
env_args = []
173+
np_flag = "-np"
174+
175+
return base + env_args + [np_flag, ppc_num_proc]
176+
92177
@staticmethod
93178
def __get_gtest_settings(repeats_count, type_task):
94179
command = [
@@ -134,34 +219,7 @@ def run_processes(self, additional_mpi_args):
134219
raise EnvironmentError(
135220
"Required environment variable 'PPC_NUM_PROC' is not set."
136221
)
137-
if self.platform == "Windows":
138-
mpi_running = (
139-
[self.mpi_exec]
140-
+ shlex.split(additional_mpi_args)
141-
+ [
142-
"-env",
143-
"PPC_NUM_THREADS",
144-
self.__ppc_env["PPC_NUM_THREADS"],
145-
"-env",
146-
"OMP_NUM_THREADS",
147-
self.__ppc_env["OMP_NUM_THREADS"],
148-
"-n",
149-
ppc_num_proc,
150-
]
151-
)
152-
else:
153-
mpi_running = (
154-
[self.mpi_exec]
155-
+ shlex.split(additional_mpi_args)
156-
+ [
157-
"-x",
158-
"PPC_NUM_THREADS",
159-
"-x",
160-
"OMP_NUM_THREADS",
161-
"-np",
162-
ppc_num_proc,
163-
]
164-
)
222+
mpi_running = self.__build_mpi_cmd(ppc_num_proc, additional_mpi_args)
165223
if not self.__ppc_env.get("PPC_ASAN_RUN"):
166224
for task_type in ["all", "mpi"]:
167225
self.__run_exec(
@@ -172,28 +230,7 @@ def run_processes(self, additional_mpi_args):
172230

173231
def run_performance(self):
174232
if not self.__ppc_env.get("PPC_ASAN_RUN"):
175-
if self.platform == "Windows":
176-
mpi_running = [
177-
self.mpi_exec,
178-
"-env",
179-
"PPC_NUM_THREADS",
180-
self.__ppc_env["PPC_NUM_THREADS"],
181-
"-env",
182-
"OMP_NUM_THREADS",
183-
self.__ppc_env["OMP_NUM_THREADS"],
184-
"-n",
185-
self.__ppc_num_proc,
186-
]
187-
else:
188-
mpi_running = [
189-
self.mpi_exec,
190-
"-x",
191-
"PPC_NUM_THREADS",
192-
"-x",
193-
"OMP_NUM_THREADS",
194-
"-np",
195-
self.__ppc_num_proc,
196-
]
233+
mpi_running = self.__build_mpi_cmd(self.__ppc_num_proc, "")
197234
for task_type in ["all", "mpi"]:
198235
self.__run_exec(
199236
mpi_running

0 commit comments

Comments
 (0)