88 changes: 67 additions & 21 deletions modules/runners/src/runners.cpp
@@ -6,12 +6,14 @@
#include <chrono>
#include <cstdint>
#include <cstdlib>
#include <exception>
#include <format>
#include <iostream>
#include <memory>
#include <random>
#include <stdexcept>
#include <string>
#include <string_view>

#include "oneapi/tbb/global_control.h"
#include "util/include/util.hpp"
@@ -51,6 +53,8 @@ void WorkerTestFailurePrinter::OnTestEnd(const ::testing::TestInfo &test_info) {
  }
  PrintProcessRank();
  base_->OnTestEnd(test_info);
  // Abort the whole MPI job on any test failure to avoid other ranks hanging on barriers.
  MPI_Abort(MPI_COMM_WORLD, EXIT_FAILURE);
}

void WorkerTestFailurePrinter::OnTestPartResult(const ::testing::TestPartResult &test_part_result) {
@@ -76,6 +80,63 @@ int RunAllTests() {
  }
  return status;
}

void SyncGTestSeed() {
  unsigned int seed = 0;
  int rank = -1;
  MPI_Comm_rank(MPI_COMM_WORLD, &rank);
  if (rank == 0) {
    try {
      seed = std::random_device{}();
    } catch (...) {
      seed = 0;
    }
    if (seed == 0) {
      const auto now = static_cast<std::uint64_t>(std::chrono::steady_clock::now().time_since_epoch().count());
      seed = static_cast<unsigned int>(((now & 0x7fffffffULL) | 1ULL));
    }
  }
  MPI_Bcast(&seed, 1, MPI_UNSIGNED, 0, MPI_COMM_WORLD);
  ::testing::GTEST_FLAG(random_seed) = static_cast<int>(seed);
}

void SyncGTestFilter() {
  int rank = -1;
  MPI_Comm_rank(MPI_COMM_WORLD, &rank);
  std::string filter = (rank == 0) ? ::testing::GTEST_FLAG(filter) : std::string{};
  int len = static_cast<int>(filter.size());
  MPI_Bcast(&len, 1, MPI_INT, 0, MPI_COMM_WORLD);
  if (rank != 0) {
    filter.resize(static_cast<std::size_t>(len));
  }
  if (len > 0) {
    MPI_Bcast(filter.data(), len, MPI_CHAR, 0, MPI_COMM_WORLD);
  }
  ::testing::GTEST_FLAG(filter) = filter;
}

bool HasFlag(int argc, char **argv, std::string_view flag) {
  for (int i = 1; i < argc; ++i) {
    if (argv[i] != nullptr && std::string_view(argv[i]) == flag) {
      return true;
    }
  }
  return false;
}

int RunAllTestsSafely() {
  try {
    return RunAllTests();
  } catch (const std::exception &e) {
    std::cerr << std::format("[ ERROR ] Exception after tests: {}", e.what()) << '\n';
    MPI_Abort(MPI_COMM_WORLD, EXIT_FAILURE);
    return EXIT_FAILURE;
  } catch (...) {
    std::cerr << "[ ERROR ] Unknown exception after tests" << '\n';
    MPI_Abort(MPI_COMM_WORLD, EXIT_FAILURE);
    return EXIT_FAILURE;
  }
}
} // namespace

int Init(int argc, char **argv) {
Expand All @@ -91,36 +152,21 @@ int Init(int argc, char **argv) {

::testing::InitGoogleTest(&argc, argv);

// Ensure consistent GoogleTest shuffle order across all MPI ranks.
unsigned int seed = 0;
int rank_for_seed = -1;
MPI_Comm_rank(MPI_COMM_WORLD, &rank_for_seed);

if (rank_for_seed == 0) {
try {
seed = std::random_device{}();
} catch (...) {
seed = 0;
}
if (seed == 0) {
const auto now = static_cast<std::uint64_t>(std::chrono::steady_clock::now().time_since_epoch().count());
seed = static_cast<unsigned int>(((now & 0x7fffffffULL) | 1ULL));
}
}

MPI_Bcast(&seed, 1, MPI_UNSIGNED, 0, MPI_COMM_WORLD);
::testing::GTEST_FLAG(random_seed) = static_cast<int>(seed);
// Synchronize GoogleTest internals across ranks to avoid divergence
SyncGTestSeed();
SyncGTestFilter();

auto &listeners = ::testing::UnitTest::GetInstance()->listeners();
int rank = -1;
MPI_Comm_rank(MPI_COMM_WORLD, &rank);
if (rank != 0 && (argc < 2 || argv[1] != std::string("--print-workers"))) {
const bool print_workers = HasFlag(argc, argv, "--print-workers");
if (rank != 0 && !print_workers) {
auto *listener = listeners.Release(listeners.default_result_printer());
listeners.Append(new WorkerTestFailurePrinter(std::shared_ptr<::testing::TestEventListener>(listener)));
}
listeners.Append(new UnreadMessagesDetector());

auto status = RunAllTests();
const int status = RunAllTestsSafely();

const int finalize_res = MPI_Finalize();
if (finalize_res != MPI_SUCCESS) {
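The SyncGTestSeed and SyncGTestFilter helpers added above both use the standard rank-0 broadcast pattern: rank 0 chooses a value, MPI_Bcast makes every rank's copy identical, and only then does each rank configure GoogleTest. A minimal, self-contained sketch of that pattern (illustrative only, not part of this patch) is:

#include <mpi.h>

#include <iostream>

int main(int argc, char **argv) {
  MPI_Init(&argc, &argv);
  int rank = -1;
  MPI_Comm_rank(MPI_COMM_WORLD, &rank);
  unsigned int value = 0;
  if (rank == 0) {
    value = 42U;  // stand-in for std::random_device{}() on rank 0
  }
  // Collective call: once it returns, every rank holds rank 0's value.
  MPI_Bcast(&value, 1, MPI_UNSIGNED, 0, MPI_COMM_WORLD);
  std::cout << "rank " << rank << " sees " << value << '\n';
  MPI_Finalize();
  return 0;
}

For variable-length data such as the filter string, the length is broadcast first so non-root ranks can size their buffers before the second broadcast, which is exactly what SyncGTestFilter does.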
16 changes: 15 additions & 1 deletion modules/util/include/util.hpp
@@ -1,6 +1,7 @@
#pragma once

#include <algorithm>
#include <array>
#include <atomic>
#include <cctype>
#include <cstdint>
@@ -26,6 +27,7 @@
#include <gtest/gtest.h>

#include <libenvpp/detail/environment.hpp>
#include <libenvpp/detail/get.hpp>
#include <nlohmann/json.hpp>

/// @brief JSON namespace used for settings and config parsing.
@@ -123,7 +125,19 @@ class ScopedPerTestEnv {
 private:
  static std::string CreateTmpDir(const std::string &token) {
    namespace fs = std::filesystem;
    const fs::path tmp = fs::temp_directory_path() / (std::string("ppc_test_") + token);
    auto make_rank_suffix = []() -> std::string {
      // Derive rank from common MPI env vars without including MPI headers
      constexpr std::array<std::string_view, 5> kRankVars = {"OMPI_COMM_WORLD_RANK", "PMI_RANK", "PMIX_RANK",
                                                             "SLURM_PROCID", "MSMPI_RANK"};
      for (auto name : kRankVars) {
        if (auto r = env::get<int>(name); r.has_value() && r.value() >= 0) {
          return std::string("_rank_") + std::to_string(r.value());
        }
      }
      return std::string{};
    };
    const std::string rank_suffix = IsUnderMpirun() ? make_rank_suffix() : std::string{};
    const fs::path tmp = fs::temp_directory_path() / (std::string("ppc_test_") + token + rank_suffix);
    std::error_code ec;
    fs::create_directories(tmp, ec);
    (void)ec;
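The per-rank suffix keeps ranks from racing on the same temporary directory when the tests run under mpirun. A rough equivalent without libenvpp, shown only to illustrate the lookup order (the patch itself uses env::get<int> as above), could look like this:

#include <array>
#include <cstdlib>
#include <string>
#include <string_view>

// Illustrative sketch only: mirrors the lambda above using std::getenv.
// Note: std::atoi yields 0 on malformed input, so this is less strict than
// the env::get<int> check used in the actual helper.
inline std::string RankSuffixSketch() {
  constexpr std::array<std::string_view, 5> kRankVars = {"OMPI_COMM_WORLD_RANK", "PMI_RANK", "PMIX_RANK",
                                                         "SLURM_PROCID", "MSMPI_RANK"};
  for (auto name : kRankVars) {
    if (const char *raw = std::getenv(std::string(name).c_str()); raw != nullptr) {
      const int rank = std::atoi(raw);
      if (rank >= 0) {
        return "_rank_" + std::to_string(rank);
      }
    }
  }
  return {};
}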
93 changes: 88 additions & 5 deletions scripts/run_tests.py
@@ -53,6 +53,17 @@ def __init__(self, verbose=False):
            self.mpi_exec = "mpiexec"
        else:
            self.mpi_exec = "mpirun"
        self.platform = platform.system()

        # Detect MPI implementation to choose compatible flags
        self.mpi_env_mode = "unknown"  # one of: openmpi, mpich, unknown
        self.mpi_np_flag = "-np"
        if self.platform == "Windows":
            # MSMPI uses -env and -n
            self.mpi_env_mode = "mpich"
            self.mpi_np_flag = "-n"
        else:
            self.mpi_env_mode, self.mpi_np_flag = self.__detect_mpi_impl()

    @staticmethod
    def __get_project_path():
@@ -88,6 +99,81 @@ def __run_exec(self, command):
        if result.returncode != 0:
            raise Exception(f"Subprocess return {result.returncode}.")

    def __detect_mpi_impl(self):
        """Detect MPI implementation and return (env_mode, np_flag).

        env_mode: 'openmpi' -> pass '-x VAR', 'mpich' -> pass '-env VAR VALUE' pairs, 'unknown' -> pass no env flags.
        np_flag: '-np' for OpenMPI/unknown, '-n' for the MPICH family.
        """
        probes = (["--version"], ["-V"], ["-v"], ["--help"], ["-help"])
        out = ""
        for args in probes:
            try:
                proc = subprocess.run(
                    [self.mpi_exec] + list(args),
                    stdout=subprocess.PIPE,
                    stderr=subprocess.STDOUT,
                    text=True,
                )
                out = (proc.stdout or "").lower()
                if out:
                    break
            except Exception:
                continue

        if "open mpi" in out or "ompi" in out:
            return "openmpi", "-np"
        if (
            "hydra" in out
            or "mpich" in out
            or "intel(r) mpi" in out
            or "intel mpi" in out
        ):
            return "mpich", "-n"
        return "unknown", "-np"

    def __build_mpi_cmd(self, ppc_num_proc, additional_mpi_args):
        base = [self.mpi_exec] + shlex.split(additional_mpi_args)

        if self.platform == "Windows":
            # MS-MPI style
            env_args = [
                "-env",
                "PPC_NUM_THREADS",
                self.__ppc_env["PPC_NUM_THREADS"],
                "-env",
                "OMP_NUM_THREADS",
                self.__ppc_env["OMP_NUM_THREADS"],
            ]
            np_args = ["-n", ppc_num_proc]
            return base + env_args + np_args

        # Non-Windows
        if self.mpi_env_mode == "openmpi":
            env_args = [
                "-x",
                "PPC_NUM_THREADS",
                "-x",
                "OMP_NUM_THREADS",
            ]
            np_flag = "-np"
        elif self.mpi_env_mode == "mpich":
            # Explicitly set env variables for all ranks
            env_args = [
                "-env",
                "PPC_NUM_THREADS",
                self.__ppc_env["PPC_NUM_THREADS"],
                "-env",
                "OMP_NUM_THREADS",
                self.__ppc_env["OMP_NUM_THREADS"],
            ]
            np_flag = "-n"
        else:
            # Unknown MPI flavor: rely on environment inheritance and default to -np
            env_args = []
            np_flag = "-np"

        return base + env_args + [np_flag, ppc_num_proc]

    @staticmethod
    def __get_gtest_settings(repeats_count, type_task):
        command = [
@@ -133,10 +219,7 @@ def run_processes(self, additional_mpi_args):
            raise EnvironmentError(
                "Required environment variable 'PPC_NUM_PROC' is not set."
            )

        mpi_running = (
            [self.mpi_exec] + shlex.split(additional_mpi_args) + ["-np", ppc_num_proc]
        )
        mpi_running = self.__build_mpi_cmd(ppc_num_proc, additional_mpi_args)
        if not self.__ppc_env.get("PPC_ASAN_RUN"):
            for task_type in ["all", "mpi"]:
                self.__run_exec(
@@ -147,7 +230,7 @@

    def run_performance(self):
        if not self.__ppc_env.get("PPC_ASAN_RUN"):
            mpi_running = [self.mpi_exec, "-np", self.__ppc_num_proc]
            mpi_running = self.__build_mpi_cmd(self.__ppc_num_proc, "")
            for task_type in ["all", "mpi"]:
                self.__run_exec(
                    mpi_running
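In effect, __build_mpi_cmd adapts the launch line to the detected MPI flavor: OpenMPI launchers receive "-x PPC_NUM_THREADS -x OMP_NUM_THREADS ... -np N", MPICH/Intel MPI/MS-MPI launchers receive explicit "-env PPC_NUM_THREADS <value> -env OMP_NUM_THREADS <value> ... -n N", and an unrecognized launcher falls back to plain "-np N" with the variables inherited from the parent environment.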