Skip to content

Commit e401a89

Browse files
authored
Merge pull request #3630 from heplesser/test_global_rng
Port mpitest/test_global_rng from SLI to Python
2 parents 81a8084 + 17dcc6f commit e401a89

File tree

2 files changed

+79
-0
lines changed

2 files changed

+79
-0
lines changed

testsuite/pytests/sli2py_mpi/mpi_test_wrapper.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -305,6 +305,35 @@ def assert_correct_results(self, tmpdirpath):
305305
pd.testing.assert_frame_equal(res[0], r)
306306

307307

308+
class MPITestAssertAllRanksEqual(MPITestWrapper):
309+
"""
310+
Assert that the results from all ranks are equal, independent of number of ranks.
311+
"""
312+
313+
def assert_correct_results(self, tmpdirpath):
314+
self.collect_results(tmpdirpath)
315+
316+
all_res = []
317+
if self._spike:
318+
raise NotImplementedError("SPIKE data not supported by MPITestAssertAllRanksEqual")
319+
320+
if self._multi:
321+
raise NotImplementedError("MULTI data not supported by MPITestAssertAllRanksEqual")
322+
323+
if self._other:
324+
all_res = list(self._other.values()) # need to get away from dict_values to allow indexing below
325+
326+
assert len(all_res) == len(self._procs_lst), "Missing data for some process numbers"
327+
assert len(all_res[0]) == self._procs_lst[0], "Data for first proc number does not match number of procs"
328+
329+
reference = all_res[0][0]
330+
for res, num_ranks in zip(all_res, self._procs_lst):
331+
assert len(res) == num_ranks, f"Got data for {len(res)} ranks, expected {num_ranks}."
332+
333+
for r in res:
334+
pd.testing.assert_frame_equal(r, reference)
335+
336+
308337
class MPITestAssertCompletes(MPITestWrapper):
309338
"""
310339
Test class that just confirms that the test code completes.
Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
# -*- coding: utf-8 -*-
2+
#
3+
# test_global_rng.py
4+
#
5+
# This file is part of NEST.
6+
#
7+
# Copyright (C) 2004 The NEST Initiative
8+
#
9+
# NEST is free software: you can redistribute it and/or modify
10+
# it under the terms of the GNU General Public License as published by
11+
# the Free Software Foundation, either version 2 of the License, or
12+
# (at your option) any later version.
13+
#
14+
# NEST is distributed in the hope that it will be useful,
15+
# but WITHOUT ANY WARRANTY; without even the implied warranty of
16+
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17+
# GNU General Public License for more details.
18+
#
19+
# You should have received a copy of the GNU General Public License
20+
# along with NEST. If not, see <http://www.gnu.org/licenses/>.
21+
22+
import pandas as pd
23+
import pytest
24+
from mpi_test_wrapper import MPITestAssertAllRanksEqual
25+
26+
27+
# Parametrization over the number of nodes here only to show hat it works
28+
@pytest.mark.skipif_incompatible_mpi
29+
@MPITestAssertAllRanksEqual([1, 2, 4], debug=False)
30+
def test_global_rng():
31+
"""
32+
Confirm that NEST random parameter used from the Python level uses globally sync'ed RNG correctly.
33+
All ranks must report identical random number sequences independent of the number of ranks.
34+
35+
The test compares connection data written to OTHER_LABEL.
36+
"""
37+
38+
import nest
39+
40+
nest.rng_seed = 12
41+
p = nest.CreateParameter("uniform", {"min": 0, "max": 1})
42+
43+
# Uncomment one of the two for loops to provoke failure
44+
# for _ in range(nest.num_processes):
45+
# p.GetValue()
46+
# for _ in range(nest.Rank()):
47+
# p.GetValue()
48+
49+
vals = pd.DataFrame([p.GetValue() for _ in range(5)])
50+
vals.to_csv(OTHER_LABEL.format(nest.num_processes, nest.Rank()), sep="\t") # noqa: F821

0 commit comments

Comments
 (0)