diff --git a/modules/runners/src/runners.cpp b/modules/runners/src/runners.cpp index df735d8e..5ab4e8e6 100644 --- a/modules/runners/src/runners.cpp +++ b/modules/runners/src/runners.cpp @@ -6,12 +6,14 @@ #include #include #include +#include #include #include #include #include #include #include +#include #include "oneapi/tbb/global_control.h" #include "util/include/util.hpp" @@ -51,6 +53,8 @@ void WorkerTestFailurePrinter::OnTestEnd(const ::testing::TestInfo &test_info) { } PrintProcessRank(); base_->OnTestEnd(test_info); + // Abort the whole MPI job on any test failure to avoid other ranks hanging on barriers. + MPI_Abort(MPI_COMM_WORLD, EXIT_FAILURE); } void WorkerTestFailurePrinter::OnTestPartResult(const ::testing::TestPartResult &test_part_result) { @@ -76,6 +80,63 @@ int RunAllTests() { } return status; } + +void SyncGTestSeed() { + unsigned int seed = 0; + int rank = -1; + MPI_Comm_rank(MPI_COMM_WORLD, &rank); + if (rank == 0) { + try { + seed = std::random_device{}(); + } catch (...) { + seed = 0; + } + if (seed == 0) { + const auto now = static_cast(std::chrono::steady_clock::now().time_since_epoch().count()); + seed = static_cast(((now & 0x7fffffffULL) | 1ULL)); + } + } + MPI_Bcast(&seed, 1, MPI_UNSIGNED, 0, MPI_COMM_WORLD); + ::testing::GTEST_FLAG(random_seed) = static_cast(seed); +} + +void SyncGTestFilter() { + int rank = -1; + MPI_Comm_rank(MPI_COMM_WORLD, &rank); + std::string filter = (rank == 0) ? ::testing::GTEST_FLAG(filter) : std::string{}; + int len = static_cast(filter.size()); + MPI_Bcast(&len, 1, MPI_INT, 0, MPI_COMM_WORLD); + if (rank != 0) { + filter.resize(static_cast(len)); + } + if (len > 0) { + MPI_Bcast(filter.data(), len, MPI_CHAR, 0, MPI_COMM_WORLD); + } + ::testing::GTEST_FLAG(filter) = filter; +} + +bool HasFlag(int argc, char **argv, std::string_view flag) { + for (int i = 1; i < argc; ++i) { + if (argv[i] != nullptr && std::string_view(argv[i]) == flag) { + return true; + } + } + return false; +} + +int RunAllTestsSafely() { + try { + return RunAllTests(); + } catch (const std::exception &e) { + std::cerr << std::format("[ ERROR ] Exception after tests: {}", e.what()) << '\n'; + MPI_Abort(MPI_COMM_WORLD, EXIT_FAILURE); + return EXIT_FAILURE; + } catch (...) { + std::cerr << "[ ERROR ] Unknown exception after tests" << '\n'; + MPI_Abort(MPI_COMM_WORLD, EXIT_FAILURE); + return EXIT_FAILURE; + } +} } // namespace int Init(int argc, char **argv) { @@ -91,36 +152,21 @@ int Init(int argc, char **argv) { ::testing::InitGoogleTest(&argc, argv); - // Ensure consistent GoogleTest shuffle order across all MPI ranks. - unsigned int seed = 0; - int rank_for_seed = -1; - MPI_Comm_rank(MPI_COMM_WORLD, &rank_for_seed); - - if (rank_for_seed == 0) { - try { - seed = std::random_device{}(); - } catch (...) { - seed = 0; - } - if (seed == 0) { - const auto now = static_cast(std::chrono::steady_clock::now().time_since_epoch().count()); - seed = static_cast(((now & 0x7fffffffULL) | 1ULL)); - } - } - - MPI_Bcast(&seed, 1, MPI_UNSIGNED, 0, MPI_COMM_WORLD); - ::testing::GTEST_FLAG(random_seed) = static_cast(seed); + // Synchronize GoogleTest internals across ranks to avoid divergence + SyncGTestSeed(); + SyncGTestFilter(); auto &listeners = ::testing::UnitTest::GetInstance()->listeners(); int rank = -1; MPI_Comm_rank(MPI_COMM_WORLD, &rank); - if (rank != 0 && (argc < 2 || argv[1] != std::string("--print-workers"))) { + const bool print_workers = HasFlag(argc, argv, "--print-workers"); + if (rank != 0 && !print_workers) { auto *listener = listeners.Release(listeners.default_result_printer()); listeners.Append(new WorkerTestFailurePrinter(std::shared_ptr<::testing::TestEventListener>(listener))); } listeners.Append(new UnreadMessagesDetector()); - auto status = RunAllTests(); + const int status = RunAllTestsSafely(); const int finalize_res = MPI_Finalize(); if (finalize_res != MPI_SUCCESS) { diff --git a/modules/util/include/util.hpp b/modules/util/include/util.hpp index 6b1846ca..ca93fd9b 100644 --- a/modules/util/include/util.hpp +++ b/modules/util/include/util.hpp @@ -1,6 +1,7 @@ #pragma once #include +#include #include #include #include @@ -26,6 +27,7 @@ #include #include +#include #include /// @brief JSON namespace used for settings and config parsing. @@ -123,7 +125,19 @@ class ScopedPerTestEnv { private: static std::string CreateTmpDir(const std::string &token) { namespace fs = std::filesystem; - const fs::path tmp = fs::temp_directory_path() / (std::string("ppc_test_") + token); + auto make_rank_suffix = []() -> std::string { + // Derive rank from common MPI env vars without including MPI headers + constexpr std::array kRankVars = {"OMPI_COMM_WORLD_RANK", "PMI_RANK", "PMIX_RANK", + "SLURM_PROCID", "MSMPI_RANK"}; + for (auto name : kRankVars) { + if (auto r = env::get(name); r.has_value() && r.value() >= 0) { + return std::string("_rank_") + std::to_string(r.value()); + } + } + return std::string{}; + }; + const std::string rank_suffix = IsUnderMpirun() ? make_rank_suffix() : std::string{}; + const fs::path tmp = fs::temp_directory_path() / (std::string("ppc_test_") + token + rank_suffix); std::error_code ec; fs::create_directories(tmp, ec); (void)ec; diff --git a/scripts/run_tests.py b/scripts/run_tests.py index ea2ea3bb..f8477c56 100755 --- a/scripts/run_tests.py +++ b/scripts/run_tests.py @@ -53,6 +53,17 @@ def __init__(self, verbose=False): self.mpi_exec = "mpiexec" else: self.mpi_exec = "mpirun" + self.platform = platform.system() + + # Detect MPI implementation to choose compatible flags + self.mpi_env_mode = "unknown" # one of: openmpi, mpich, unknown + self.mpi_np_flag = "-np" + if self.platform == "Windows": + # MSMPI uses -env and -n + self.mpi_env_mode = "mpich" + self.mpi_np_flag = "-n" + else: + self.mpi_env_mode, self.mpi_np_flag = self.__detect_mpi_impl() @staticmethod def __get_project_path(): @@ -88,6 +99,81 @@ def __run_exec(self, command): if result.returncode != 0: raise Exception(f"Subprocess return {result.returncode}.") + def __detect_mpi_impl(self): + """Detect MPI implementation and return (env_mode, np_flag). + env_mode: 'openmpi' -> use '-x VAR', 'mpich' -> use '-genvlist VAR1,VAR2', 'unknown' -> pass no env flags. + np_flag: '-np' for OpenMPI/unknown, '-n' for MPICH-family. + """ + probes = (["--version"], ["-V"], ["-v"], ["--help"], ["-help"]) + out = "" + for args in probes: + try: + proc = subprocess.run( + [self.mpi_exec] + list(args), + stdout=subprocess.PIPE, + stderr=subprocess.STDOUT, + text=True, + ) + out = (proc.stdout or "").lower() + if out: + break + except Exception: + continue + + if "open mpi" in out or "ompi" in out: + return "openmpi", "-np" + if ( + "hydra" in out + or "mpich" in out + or "intel(r) mpi" in out + or "intel mpi" in out + ): + return "mpich", "-n" + return "unknown", "-np" + + def __build_mpi_cmd(self, ppc_num_proc, additional_mpi_args): + base = [self.mpi_exec] + shlex.split(additional_mpi_args) + + if self.platform == "Windows": + # MS-MPI style + env_args = [ + "-env", + "PPC_NUM_THREADS", + self.__ppc_env["PPC_NUM_THREADS"], + "-env", + "OMP_NUM_THREADS", + self.__ppc_env["OMP_NUM_THREADS"], + ] + np_args = ["-n", ppc_num_proc] + return base + env_args + np_args + + # Non-Windows + if self.mpi_env_mode == "openmpi": + env_args = [ + "-x", + "PPC_NUM_THREADS", + "-x", + "OMP_NUM_THREADS", + ] + np_flag = "-np" + elif self.mpi_env_mode == "mpich": + # Explicitly set env variables for all ranks + env_args = [ + "-env", + "PPC_NUM_THREADS", + self.__ppc_env["PPC_NUM_THREADS"], + "-env", + "OMP_NUM_THREADS", + self.__ppc_env["OMP_NUM_THREADS"], + ] + np_flag = "-n" + else: + # Unknown MPI flavor: rely on environment inheritance and default to -np + env_args = [] + np_flag = "-np" + + return base + env_args + [np_flag, ppc_num_proc] + @staticmethod def __get_gtest_settings(repeats_count, type_task): command = [ @@ -133,10 +219,7 @@ def run_processes(self, additional_mpi_args): raise EnvironmentError( "Required environment variable 'PPC_NUM_PROC' is not set." ) - - mpi_running = ( - [self.mpi_exec] + shlex.split(additional_mpi_args) + ["-np", ppc_num_proc] - ) + mpi_running = self.__build_mpi_cmd(ppc_num_proc, additional_mpi_args) if not self.__ppc_env.get("PPC_ASAN_RUN"): for task_type in ["all", "mpi"]: self.__run_exec( @@ -147,7 +230,7 @@ def run_processes(self, additional_mpi_args): def run_performance(self): if not self.__ppc_env.get("PPC_ASAN_RUN"): - mpi_running = [self.mpi_exec, "-np", self.__ppc_num_proc] + mpi_running = self.__build_mpi_cmd(self.__ppc_num_proc, "") for task_type in ["all", "mpi"]: self.__run_exec( mpi_running