Skip to content

Commit 32b0cb2

Browse files
committed
Add wrapper variant checking for identical results on all ranks
1 parent e1f46d6 commit 32b0cb2

File tree

1 file changed

+29
-0
lines changed

1 file changed

+29
-0
lines changed

testsuite/pytests/sli2py_mpi/mpi_test_wrapper.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -305,6 +305,35 @@ def assert_correct_results(self, tmpdirpath):
305305
pd.testing.assert_frame_equal(res[0], r)
306306

307307

308+
class MPITestAssertAllRanksEqual(MPITestWrapper):
309+
"""
310+
Assert that the results from all ranks are equal, independent of number of ranks.
311+
"""
312+
313+
def assert_correct_results(self, tmpdirpath):
314+
self.collect_results(tmpdirpath)
315+
316+
all_res = []
317+
if self._spike:
318+
raise NotImplementedError("SPIKE data not supported by MPITestAssertAllRanksEqual")
319+
320+
if self._multi:
321+
raise NotImplementedError("MULTI data not supported by MPITestAssertAllRanksEqual")
322+
323+
if self._other:
324+
all_res = list(self._other.values()) # need to get away from dict_values to allow indexing below
325+
326+
assert len(all_res) == len(self._procs_lst), "Missing data for some process numbers"
327+
assert len(all_res[0]) == self._procs_lst[0], "Data for first proc number does not match number of procs"
328+
329+
reference = all_res[0][0]
330+
for res, num_ranks in zip(all_res, self._procs_lst):
331+
assert len(res) == num_ranks, f"Got data for {len(res)} ranks, expected {num_ranks}."
332+
333+
for r in res:
334+
pd.testing.assert_frame_equal(r, reference)
335+
336+
308337
class MPITestAssertCompletes(MPITestWrapper):
309338
"""
310339
Test class that just confirms that the test code completes.

0 commit comments

Comments
 (0)