5151 They must be used as `label` for spike recorders and multimeters, respectively,
5252 or for other files for output data (TAB-separated CSV files). They are format
5353 strings expecting the number of processes with which NEST is run as argument.
54+ - A function taking a single argument ``all_res`` and performing assertions on it
55+ can be passed to the constructor of the wrapper. This allows test-specific additional
56+ checks on the collected data from all MPI processes without requiring MPI4Py.
5457- Set `debug=True` on the decorator to see debug output and keep the
5558 temporary directory that has been created (latter works only in
5659 Python 3.12 and later)
@@ -126,7 +129,14 @@ class MPITestWrapper:
126129 """
127130 )
128131
129- def __init__ (self , procs_lst , debug = False ):
132+ def __init__ (self , procs_lst , debug = False , specific_assert = None ):
133+ """
134+ procs_lst : list of number of process to run tests for, e.g., [1, 2, 4]
135+ debug : if True, provide output during execution and do not delete temp directory (Python >=3.12)
136+ specific_assert : function taking ``all_res`` as input and performing test-specific assertions
137+ after the overall check performed by the wrapper class
138+ """
139+
130140 try :
131141 iter (procs_lst )
132142 except TypeError :
@@ -137,6 +147,7 @@ def __init__(self, procs_lst, debug=False):
137147 self ._spike = None
138148 self ._multi = None
139149 self ._other = None
150+ self ._specific_assert = specific_assert
140151
141152 @staticmethod
142153 def _pure_test_func (func ):
@@ -292,33 +303,29 @@ class MPITestAssertEqual(MPITestWrapper):
292303 def assert_correct_results (self , tmpdirpath ):
293304 self .collect_results (tmpdirpath )
294305
295- all_res = []
306+ all_res = {}
296307 if self ._spike :
297308 # For each number of procs, combine results across VPs and sort by time and sender
298309
299310 # Include only frames containing at least one non-nan value so pandas knows datatypes.
300311 # .all() returns True for empty arrays.
301- all_res .append (
302- [
303- pd .concat (self ._drop_empty_dataframes (spikes ), ignore_index = True ).sort_values (
304- by = ["time_step" , "time_offset" , "sender" ], ignore_index = True
305- )
306- for spikes in self ._spike .values ()
307- ]
308- )
312+ all_res ["spike" ] = [
313+ pd .concat (self ._drop_empty_dataframes (spikes ), ignore_index = True ).sort_values (
314+ by = ["time_step" , "time_offset" , "sender" ], ignore_index = True
315+ )
316+ for spikes in self ._spike .values ()
317+ ]
309318
310319 if self ._multi :
311320 # For each number of procs, combine results across VPs and sort by time and sender
312321 # Include only frames containing at least one non-nan value so pandas knows datatypes.
313322 # .all() returns True for empty arrays.
314- all_res .append (
315- [
316- pd .concat (self ._drop_empty_dataframes (mmdata ), ignore_index = True ).sort_values (
317- by = ["time_step" , "time_offset" , "sender" ], ignore_index = True
318- )
319- for mmdata in self ._multi .values ()
320- ]
321- )
323+ all_res ["multi" ] = [
324+ pd .concat (self ._drop_empty_dataframes (mmdata ), ignore_index = True ).sort_values (
325+ by = ["time_step" , "time_offset" , "sender" ], ignore_index = True
326+ )
327+ for mmdata in self ._multi .values ()
328+ ]
322329
323330 if self ._other :
324331 # For each number of procs, combine across ranks or VPs (depends on what test has written) and
@@ -331,22 +338,24 @@ def assert_correct_results(self, tmpdirpath):
331338 # [0] then picks the first DataFrame from that list
332339 # columns need to be converted to list() to be passed to sort_values()
333340 all_columns = list (next (iter (self ._other .values ()))[0 ].columns )
334- all_res .append (
335- [
336- pd .concat (self ._drop_empty_dataframes (others ), ignore_index = True ).sort_values (
337- by = all_columns , ignore_index = True
338- )
339- for others in self ._other .values ()
340- ]
341- )
341+ all_res ["other" ] = [
342+ pd .concat (self ._drop_empty_dataframes (others ), ignore_index = True ).sort_values (
343+ by = all_columns , ignore_index = True
344+ )
345+ for others in self ._other .values ()
346+ ]
342347
343348 assert all_res , "No test data collected"
344- for res in all_res :
349+
350+ for res in all_res .values ():
345351 assert len (res ) == len (self ._procs_lst ), "Could not collect data for all procs"
346352
347353 for r in res [1 :]:
348354 pd .testing .assert_frame_equal (res [0 ], r )
349355
356+ if self ._specific_assert :
357+ self ._specific_assert (all_res )
358+
350359
351360class MPITestAssertAllRanksEqual (MPITestWrapper ):
352361 """
@@ -376,6 +385,9 @@ def assert_correct_results(self, tmpdirpath):
376385 for r in res :
377386 pd .testing .assert_frame_equal (r , reference )
378387
388+ if self ._specific_assert :
389+ self ._specific_assert (all_res )
390+
379391
380392class MPITestAssertCompletes (MPITestWrapper ):
381393 """
@@ -385,4 +397,4 @@ class MPITestAssertCompletes(MPITestWrapper):
385397 """
386398
387399 def assert_correct_results (self , tmpdirpath ):
388- pass
400+ assert self . _specific_assert is None , "MPITestAssertCompletes does not support specific_assert."
0 commit comments