Skip to content

Commit 4b2164d

Browse files
committed
Add support for test-specific assert functions
1 parent 3bc2042 commit 4b2164d

File tree

1 file changed

+39
-28
lines changed

1 file changed

+39
-28
lines changed

testsuite/pytests/sli2py_mpi/mpi_test_wrapper.py

Lines changed: 39 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,9 @@
5151
They must be used as `label` for spike recorders and multimeters, respectively,
5252
or for other files for output data (TAB-separated CSV files). They are format
5353
strings expecting the number of processes with which NEST is run as argument.
54+
- A function taking a single argument ``all_res`` and performing assertions on it
55+
can be passed to the constructor of the wrapper. This allows test-specific additional
56+
checks on the collected data from all MPI processes without requiring MPI4Py.
5457
- Set `debug=True` on the decorator to see debug output and keep the
5558
temporary directory that has been created (latter works only in
5659
Python 3.12 and later)
@@ -126,7 +129,13 @@ class MPITestWrapper:
126129
"""
127130
)
128131

129-
def __init__(self, procs_lst, debug=False):
132+
def __init__(self, procs_lst, debug=False, specific_assert=None):
133+
"""
134+
procs_lst : list of number of process to run tests for, e.g., [1, 2, 4]
135+
debug : if True, provide output during execution and do not delete temp directory (Python >=3.12)
136+
specific_assert : function taking ``all_res`` as input and performing test-specific assertions after the overall check performed by the wrapper class
137+
"""
138+
130139
try:
131140
iter(procs_lst)
132141
except TypeError:
@@ -137,6 +146,7 @@ def __init__(self, procs_lst, debug=False):
137146
self._spike = None
138147
self._multi = None
139148
self._other = None
149+
self._specific_assert = specific_assert
140150

141151
@staticmethod
142152
def _pure_test_func(func):
@@ -292,33 +302,29 @@ class MPITestAssertEqual(MPITestWrapper):
292302
def assert_correct_results(self, tmpdirpath):
293303
self.collect_results(tmpdirpath)
294304

295-
all_res = []
305+
all_res = {}
296306
if self._spike:
297307
# For each number of procs, combine results across VPs and sort by time and sender
298308

299309
# Include only frames containing at least one non-nan value so pandas knows datatypes.
300310
# .all() returns True for empty arrays.
301-
all_res.append(
302-
[
303-
pd.concat(self._drop_empty_dataframes(spikes), ignore_index=True).sort_values(
304-
by=["time_step", "time_offset", "sender"], ignore_index=True
305-
)
306-
for spikes in self._spike.values()
307-
]
308-
)
311+
all_res["spike"] = [
312+
pd.concat(self._drop_empty_dataframes(spikes), ignore_index=True).sort_values(
313+
by=["time_step", "time_offset", "sender"], ignore_index=True
314+
)
315+
for spikes in self._spike.values()
316+
]
309317

310318
if self._multi:
311319
# For each number of procs, combine results across VPs and sort by time and sender
312320
# Include only frames containing at least one non-nan value so pandas knows datatypes.
313321
# .all() returns True for empty arrays.
314-
all_res.append(
315-
[
316-
pd.concat(self._drop_empty_dataframes(mmdata), ignore_index=True).sort_values(
317-
by=["time_step", "time_offset", "sender"], ignore_index=True
318-
)
319-
for mmdata in self._multi.values()
320-
]
321-
)
322+
all_res["multi"] = [
323+
pd.concat(self._drop_empty_dataframes(mmdata), ignore_index=True).sort_values(
324+
by=["time_step", "time_offset", "sender"], ignore_index=True
325+
)
326+
for mmdata in self._multi.values()
327+
]
322328

323329
if self._other:
324330
# For each number of procs, combine across ranks or VPs (depends on what test has written) and
@@ -331,22 +337,23 @@ def assert_correct_results(self, tmpdirpath):
331337
# [0] then picks the first DataFrame from that list
332338
# columns need to be converted to list() to be passed to sort_values()
333339
all_columns = list(next(iter(self._other.values()))[0].columns)
334-
all_res.append(
335-
[
336-
pd.concat(self._drop_empty_dataframes(others), ignore_index=True).sort_values(
337-
by=all_columns, ignore_index=True
338-
)
339-
for others in self._other.values()
340-
]
341-
)
340+
all_res["other"] = [
341+
pd.concat(self._drop_empty_dataframes(others), ignore_index=True).sort_values(
342+
by=all_columns, ignore_index=True
343+
)
344+
for others in self._other.values()
345+
]
342346

343347
assert all_res, "No test data collected"
344-
for res in all_res:
348+
for res in all_res.values():
345349
assert len(res) == len(self._procs_lst), "Could not collect data for all procs"
346350

347351
for r in res[1:]:
348352
pd.testing.assert_frame_equal(res[0], r)
349353

354+
if self._specific_assert:
355+
self._specific_assert(all_res)
356+
350357

351358
class MPITestAssertAllRanksEqual(MPITestWrapper):
352359
"""
@@ -376,6 +383,9 @@ def assert_correct_results(self, tmpdirpath):
376383
for r in res:
377384
pd.testing.assert_frame_equal(r, reference)
378385

386+
if self._specific_assert:
387+
self._specific_assert(all_res)
388+
379389

380390
class MPITestAssertCompletes(MPITestWrapper):
381391
"""
@@ -385,4 +395,5 @@ class MPITestAssertCompletes(MPITestWrapper):
385395
"""
386396

387397
def assert_correct_results(self, tmpdirpath):
388-
pass
398+
if self._specific_assert:
399+
self._specific_assert(all_res)

0 commit comments

Comments
 (0)