@@ -55,6 +55,16 @@ def __init__(self, verbose=False):
5555 self .mpi_exec = "mpirun"
5656 self .platform = platform .system ()
5757
58+ # Detect MPI implementation to choose compatible flags
59+ self .mpi_env_mode = "unknown" # one of: openmpi, mpich, unknown
60+ self .mpi_np_flag = "-np"
61+ if self .platform == "Windows" :
62+ # MSMPI uses -env and -n
63+ self .mpi_env_mode = "mpich"
64+ self .mpi_np_flag = "-n"
65+ else :
66+ self .mpi_env_mode , self .mpi_np_flag = self .__detect_mpi_impl ()
67+
5868 @staticmethod
5969 def __get_project_path ():
6070 script_path = Path (__file__ ).resolve () # Absolute path of the script
@@ -89,6 +99,81 @@ def __run_exec(self, command):
8999 if result .returncode != 0 :
90100 raise Exception (f"Subprocess return { result .returncode } ." )
91101
102+ def __detect_mpi_impl (self ):
103+ """Detect MPI implementation and return (env_mode, np_flag).
104+ env_mode: 'openmpi' -> use '-x VAR', 'mpich' -> use '-genvlist VAR1,VAR2', 'unknown' -> pass no env flags.
105+ np_flag: '-np' for OpenMPI/unknown, '-n' for MPICH-family.
106+ """
107+ probes = (["--version" ], ["-V" ], ["-v" ], ["--help" ], ["-help" ])
108+ out = ""
109+ for args in probes :
110+ try :
111+ proc = subprocess .run (
112+ [self .mpi_exec ] + list (args ),
113+ stdout = subprocess .PIPE ,
114+ stderr = subprocess .STDOUT ,
115+ text = True ,
116+ )
117+ out = (proc .stdout or "" ).lower ()
118+ if out :
119+ break
120+ except Exception :
121+ continue
122+
123+ if "open mpi" in out or "ompi" in out :
124+ return "openmpi" , "-np"
125+ if (
126+ "hydra" in out
127+ or "mpich" in out
128+ or "intel(r) mpi" in out
129+ or "intel mpi" in out
130+ ):
131+ return "mpich" , "-n"
132+ return "unknown" , "-np"
133+
134+ def __build_mpi_cmd (self , ppc_num_proc , additional_mpi_args ):
135+ base = [self .mpi_exec ] + shlex .split (additional_mpi_args )
136+
137+ if self .platform == "Windows" :
138+ # MS-MPI style
139+ env_args = [
140+ "-env" ,
141+ "PPC_NUM_THREADS" ,
142+ self .__ppc_env ["PPC_NUM_THREADS" ],
143+ "-env" ,
144+ "OMP_NUM_THREADS" ,
145+ self .__ppc_env ["OMP_NUM_THREADS" ],
146+ ]
147+ np_args = ["-n" , ppc_num_proc ]
148+ return base + env_args + np_args
149+
150+ # Non-Windows
151+ if self .mpi_env_mode == "openmpi" :
152+ env_args = [
153+ "-x" ,
154+ "PPC_NUM_THREADS" ,
155+ "-x" ,
156+ "OMP_NUM_THREADS" ,
157+ ]
158+ np_flag = "-np"
159+ elif self .mpi_env_mode == "mpich" :
160+ # Explicitly set env variables for all ranks
161+ env_args = [
162+ "-env" ,
163+ "PPC_NUM_THREADS" ,
164+ self .__ppc_env ["PPC_NUM_THREADS" ],
165+ "-env" ,
166+ "OMP_NUM_THREADS" ,
167+ self .__ppc_env ["OMP_NUM_THREADS" ],
168+ ]
169+ np_flag = "-n"
170+ else :
171+ # Unknown MPI flavor: rely on environment inheritance and default to -np
172+ env_args = []
173+ np_flag = "-np"
174+
175+ return base + env_args + [np_flag , ppc_num_proc ]
176+
92177 @staticmethod
93178 def __get_gtest_settings (repeats_count , type_task ):
94179 command = [
@@ -134,34 +219,7 @@ def run_processes(self, additional_mpi_args):
134219 raise EnvironmentError (
135220 "Required environment variable 'PPC_NUM_PROC' is not set."
136221 )
137- if self .platform == "Windows" :
138- mpi_running = (
139- [self .mpi_exec ]
140- + shlex .split (additional_mpi_args )
141- + [
142- "-env" ,
143- "PPC_NUM_THREADS" ,
144- self .__ppc_env ["PPC_NUM_THREADS" ],
145- "-env" ,
146- "OMP_NUM_THREADS" ,
147- self .__ppc_env ["OMP_NUM_THREADS" ],
148- "-n" ,
149- ppc_num_proc ,
150- ]
151- )
152- else :
153- mpi_running = (
154- [self .mpi_exec ]
155- + shlex .split (additional_mpi_args )
156- + [
157- "-x" ,
158- "PPC_NUM_THREADS" ,
159- "-x" ,
160- "OMP_NUM_THREADS" ,
161- "-np" ,
162- ppc_num_proc ,
163- ]
164- )
222+ mpi_running = self .__build_mpi_cmd (ppc_num_proc , additional_mpi_args )
165223 if not self .__ppc_env .get ("PPC_ASAN_RUN" ):
166224 for task_type in ["all" , "mpi" ]:
167225 self .__run_exec (
@@ -172,28 +230,7 @@ def run_processes(self, additional_mpi_args):
172230
173231 def run_performance (self ):
174232 if not self .__ppc_env .get ("PPC_ASAN_RUN" ):
175- if self .platform == "Windows" :
176- mpi_running = [
177- self .mpi_exec ,
178- "-env" ,
179- "PPC_NUM_THREADS" ,
180- self .__ppc_env ["PPC_NUM_THREADS" ],
181- "-env" ,
182- "OMP_NUM_THREADS" ,
183- self .__ppc_env ["OMP_NUM_THREADS" ],
184- "-n" ,
185- self .__ppc_num_proc ,
186- ]
187- else :
188- mpi_running = [
189- self .mpi_exec ,
190- "-x" ,
191- "PPC_NUM_THREADS" ,
192- "-x" ,
193- "OMP_NUM_THREADS" ,
194- "-np" ,
195- self .__ppc_num_proc ,
196- ]
233+ mpi_running = self .__build_mpi_cmd (self .__ppc_num_proc , "" )
197234 for task_type in ["all" , "mpi" ]:
198235 self .__run_exec (
199236 mpi_running
0 commit comments