5151 They must be used as `label` for spike recorders and multimeters, respectively,
5252 or for other files for output data (TAB-separated CSV files). They are format
5353 strings expecting the number of processes with which NEST is run as argument.
54+ - A function taking a single argument ``all_res`` and performing assertions on it
55+ can be passed to the constructor of the wrapper. This allows test-specific additional
56+ checks on the collected data from all MPI processes without requiring MPI4Py.
5457- Set `debug=True` on the decorator to see debug output and keep the
5558 temporary directory that has been created (latter works only in
5659 Python 3.12 and later)
@@ -126,7 +129,13 @@ class MPITestWrapper:
126129 """
127130 )
128131
129- def __init__ (self , procs_lst , debug = False ):
132+ def __init__ (self , procs_lst , debug = False , specific_assert = None ):
133+ """
134+ procs_lst : list of number of process to run tests for, e.g., [1, 2, 4]
135+ debug : if True, provide output during execution and do not delete temp directory (Python >=3.12)
136+ specific_assert : function taking ``all_res`` as input and performing test-specific assertions after the overall check performed by the wrapper class
137+ """
138+
130139 try :
131140 iter (procs_lst )
132141 except TypeError :
@@ -137,6 +146,7 @@ def __init__(self, procs_lst, debug=False):
137146 self ._spike = None
138147 self ._multi = None
139148 self ._other = None
149+ self ._specific_assert = specific_assert
140150
141151 @staticmethod
142152 def _pure_test_func (func ):
@@ -292,33 +302,29 @@ class MPITestAssertEqual(MPITestWrapper):
292302 def assert_correct_results (self , tmpdirpath ):
293303 self .collect_results (tmpdirpath )
294304
295- all_res = []
305+ all_res = {}
296306 if self ._spike :
297307 # For each number of procs, combine results across VPs and sort by time and sender
298308
299309 # Include only frames containing at least one non-nan value so pandas knows datatypes.
300310 # .all() returns True for empty arrays.
301- all_res .append (
302- [
303- pd .concat (self ._drop_empty_dataframes (spikes ), ignore_index = True ).sort_values (
304- by = ["time_step" , "time_offset" , "sender" ], ignore_index = True
305- )
306- for spikes in self ._spike .values ()
307- ]
308- )
311+ all_res ["spike" ] = [
312+ pd .concat (self ._drop_empty_dataframes (spikes ), ignore_index = True ).sort_values (
313+ by = ["time_step" , "time_offset" , "sender" ], ignore_index = True
314+ )
315+ for spikes in self ._spike .values ()
316+ ]
309317
310318 if self ._multi :
311319 # For each number of procs, combine results across VPs and sort by time and sender
312320 # Include only frames containing at least one non-nan value so pandas knows datatypes.
313321 # .all() returns True for empty arrays.
314- all_res .append (
315- [
316- pd .concat (self ._drop_empty_dataframes (mmdata ), ignore_index = True ).sort_values (
317- by = ["time_step" , "time_offset" , "sender" ], ignore_index = True
318- )
319- for mmdata in self ._multi .values ()
320- ]
321- )
322+ all_res ["multi" ] = [
323+ pd .concat (self ._drop_empty_dataframes (mmdata ), ignore_index = True ).sort_values (
324+ by = ["time_step" , "time_offset" , "sender" ], ignore_index = True
325+ )
326+ for mmdata in self ._multi .values ()
327+ ]
322328
323329 if self ._other :
324330 # For each number of procs, combine across ranks or VPs (depends on what test has written) and
@@ -331,22 +337,23 @@ def assert_correct_results(self, tmpdirpath):
331337 # [0] then picks the first DataFrame from that list
332338 # columns need to be converted to list() to be passed to sort_values()
333339 all_columns = list (next (iter (self ._other .values ()))[0 ].columns )
334- all_res .append (
335- [
336- pd .concat (self ._drop_empty_dataframes (others ), ignore_index = True ).sort_values (
337- by = all_columns , ignore_index = True
338- )
339- for others in self ._other .values ()
340- ]
341- )
340+ all_res ["other" ] = [
341+ pd .concat (self ._drop_empty_dataframes (others ), ignore_index = True ).sort_values (
342+ by = all_columns , ignore_index = True
343+ )
344+ for others in self ._other .values ()
345+ ]
342346
343347 assert all_res , "No test data collected"
344- for res in all_res :
348+ for res in all_res . values () :
345349 assert len (res ) == len (self ._procs_lst ), "Could not collect data for all procs"
346350
347351 for r in res [1 :]:
348352 pd .testing .assert_frame_equal (res [0 ], r )
349353
354+ if self ._specific_assert :
355+ self ._specific_assert (all_res )
356+
350357
351358class MPITestAssertAllRanksEqual (MPITestWrapper ):
352359 """
@@ -376,6 +383,9 @@ def assert_correct_results(self, tmpdirpath):
376383 for r in res :
377384 pd .testing .assert_frame_equal (r , reference )
378385
386+ if self ._specific_assert :
387+ self ._specific_assert (all_res )
388+
379389
380390class MPITestAssertCompletes (MPITestWrapper ):
381391 """
@@ -385,4 +395,5 @@ class MPITestAssertCompletes(MPITestWrapper):
385395 """
386396
387397 def assert_correct_results (self , tmpdirpath ):
388- pass
398+ if self ._specific_assert :
399+ self ._specific_assert (all_res )
0 commit comments