Skip to content

Commit 9de5c57

Browse files
committed
Port test_global_rng from SLI to Py
1 parent 32b0cb2 commit 9de5c57

File tree

1 file changed

+50
-0
lines changed

1 file changed

+50
-0
lines changed
Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
# -*- coding: utf-8 -*-
2+
#
3+
# test_global_rng.py
4+
#
5+
# This file is part of NEST.
6+
#
7+
# Copyright (C) 2004 The NEST Initiative
8+
#
9+
# NEST is free software: you can redistribute it and/or modify
10+
# it under the terms of the GNU General Public License as published by
11+
# the Free Software Foundation, either version 2 of the License, or
12+
# (at your option) any later version.
13+
#
14+
# NEST is distributed in the hope that it will be useful,
15+
# but WITHOUT ANY WARRANTY; without even the implied warranty of
16+
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17+
# GNU General Public License for more details.
18+
#
19+
# You should have received a copy of the GNU General Public License
20+
# along with NEST. If not, see <http://www.gnu.org/licenses/>.
21+
22+
import pandas as pd
23+
import pytest
24+
from mpi_test_wrapper import MPITestAssertAllRanksEqual
25+
26+
27+
# Parametrization over the number of nodes here only to show hat it works
28+
@pytest.mark.skipif_incompatible_mpi
29+
@MPITestAssertAllRanksEqual([1, 2, 4], debug=False)
30+
def test_global_rng():
31+
"""
32+
Confirm that NEST random parameter used from the Python level uses globally sync'ed RNG correctly.
33+
All ranks must report identical random number sequences independent of the number of ranks.
34+
35+
The test compares connection data written to OTHER_LABEL.
36+
"""
37+
38+
import nest
39+
40+
nest.rng_seed = 12
41+
p = nest.CreateParameter("uniform", {"min": 0, "max": 1})
42+
43+
# Uncomment one of the two for loops to provoke failure
44+
# for _ in range(nest.num_processes):
45+
# p.GetValue()
46+
# for _ in range(nest.Rank()):
47+
# p.GetValue()
48+
49+
vals = pd.DataFrame([p.GetValue() for _ in range(5)])
50+
vals.to_csv(OTHER_LABEL.format(nest.num_processes, nest.Rank()), sep="\t") # noqa: F821

0 commit comments

Comments
 (0)