Skip to content

Commit ebcf651

Browse files
authored
Merge pull request nest#3644 from heplesser/test_sin_gen_56
Port tests for identical and individual spike trains from poisson generator from SLI to Py
2 parents 7c66db5 + 425066e commit ebcf651

File tree

5 files changed

+209
-249
lines changed

5 files changed

+209
-249
lines changed

testsuite/mpitests/test_sinusoidal_poisson_generator_5.sli

Lines changed: 0 additions & 104 deletions
This file was deleted.

testsuite/mpitests/test_sinusoidal_poisson_generator_6.sli

Lines changed: 0 additions & 117 deletions
This file was deleted.

testsuite/pytests/sli2py_mpi/mpi_test_wrapper.py

Lines changed: 40 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,9 @@
5151
They must be used as `label` for spike recorders and multimeters, respectively,
5252
or for other files for output data (TAB-separated CSV files). They are format
5353
strings expecting the number of processes with which NEST is run as argument.
54+
- A function taking a single argument ``all_res`` and performing assertions on it
55+
can be passed to the constructor of the wrapper. This allows test-specific additional
56+
checks on the collected data from all MPI processes without requiring MPI4Py.
5457
- Set `debug=True` on the decorator to see debug output and keep the
5558
temporary directory that has been created (latter works only in
5659
Python 3.12 and later)
@@ -126,7 +129,14 @@ class MPITestWrapper:
126129
"""
127130
)
128131

129-
def __init__(self, procs_lst, debug=False):
132+
def __init__(self, procs_lst, debug=False, specific_assert=None):
133+
"""
134+
procs_lst : list of number of process to run tests for, e.g., [1, 2, 4]
135+
debug : if True, provide output during execution and do not delete temp directory (Python >=3.12)
136+
specific_assert : function taking ``all_res`` as input and performing test-specific assertions
137+
after the overall check performed by the wrapper class
138+
"""
139+
130140
try:
131141
iter(procs_lst)
132142
except TypeError:
@@ -137,6 +147,7 @@ def __init__(self, procs_lst, debug=False):
137147
self._spike = None
138148
self._multi = None
139149
self._other = None
150+
self._specific_assert = specific_assert
140151

141152
@staticmethod
142153
def _pure_test_func(func):
@@ -292,33 +303,29 @@ class MPITestAssertEqual(MPITestWrapper):
292303
def assert_correct_results(self, tmpdirpath):
293304
self.collect_results(tmpdirpath)
294305

295-
all_res = []
306+
all_res = {}
296307
if self._spike:
297308
# For each number of procs, combine results across VPs and sort by time and sender
298309

299310
# Include only frames containing at least one non-nan value so pandas knows datatypes.
300311
# .all() returns True for empty arrays.
301-
all_res.append(
302-
[
303-
pd.concat(self._drop_empty_dataframes(spikes), ignore_index=True).sort_values(
304-
by=["time_step", "time_offset", "sender"], ignore_index=True
305-
)
306-
for spikes in self._spike.values()
307-
]
308-
)
312+
all_res["spike"] = [
313+
pd.concat(self._drop_empty_dataframes(spikes), ignore_index=True).sort_values(
314+
by=["time_step", "time_offset", "sender"], ignore_index=True
315+
)
316+
for spikes in self._spike.values()
317+
]
309318

310319
if self._multi:
311320
# For each number of procs, combine results across VPs and sort by time and sender
312321
# Include only frames containing at least one non-nan value so pandas knows datatypes.
313322
# .all() returns True for empty arrays.
314-
all_res.append(
315-
[
316-
pd.concat(self._drop_empty_dataframes(mmdata), ignore_index=True).sort_values(
317-
by=["time_step", "time_offset", "sender"], ignore_index=True
318-
)
319-
for mmdata in self._multi.values()
320-
]
321-
)
323+
all_res["multi"] = [
324+
pd.concat(self._drop_empty_dataframes(mmdata), ignore_index=True).sort_values(
325+
by=["time_step", "time_offset", "sender"], ignore_index=True
326+
)
327+
for mmdata in self._multi.values()
328+
]
322329

323330
if self._other:
324331
# For each number of procs, combine across ranks or VPs (depends on what test has written) and
@@ -331,22 +338,24 @@ def assert_correct_results(self, tmpdirpath):
331338
# [0] then picks the first DataFrame from that list
332339
# columns need to be converted to list() to be passed to sort_values()
333340
all_columns = list(next(iter(self._other.values()))[0].columns)
334-
all_res.append(
335-
[
336-
pd.concat(self._drop_empty_dataframes(others), ignore_index=True).sort_values(
337-
by=all_columns, ignore_index=True
338-
)
339-
for others in self._other.values()
340-
]
341-
)
341+
all_res["other"] = [
342+
pd.concat(self._drop_empty_dataframes(others), ignore_index=True).sort_values(
343+
by=all_columns, ignore_index=True
344+
)
345+
for others in self._other.values()
346+
]
342347

343348
assert all_res, "No test data collected"
344-
for res in all_res:
349+
350+
for res in all_res.values():
345351
assert len(res) == len(self._procs_lst), "Could not collect data for all procs"
346352

347353
for r in res[1:]:
348354
pd.testing.assert_frame_equal(res[0], r)
349355

356+
if self._specific_assert:
357+
self._specific_assert(all_res)
358+
350359

351360
class MPITestAssertAllRanksEqual(MPITestWrapper):
352361
"""
@@ -376,6 +385,9 @@ def assert_correct_results(self, tmpdirpath):
376385
for r in res:
377386
pd.testing.assert_frame_equal(r, reference)
378387

388+
if self._specific_assert:
389+
self._specific_assert(all_res)
390+
379391

380392
class MPITestAssertCompletes(MPITestWrapper):
381393
"""
@@ -385,4 +397,4 @@ class MPITestAssertCompletes(MPITestWrapper):
385397
"""
386398

387399
def assert_correct_results(self, tmpdirpath):
388-
pass
400+
assert self._specific_assert is None, "MPITestAssertCompletes does not support specific_assert."

0 commit comments

Comments
 (0)