Skip to content

Commit be6d058

Browse files
committed
Port tests for identical and individual spike trains from poisson generator from SLI to Py
1 parent 4b2164d commit be6d058

File tree

4 files changed

+169
-221
lines changed

4 files changed

+169
-221
lines changed

testsuite/mpitests/test_sinusoidal_poisson_generator_5.sli

Lines changed: 0 additions & 104 deletions
This file was deleted.

testsuite/mpitests/test_sinusoidal_poisson_generator_6.sli

Lines changed: 0 additions & 117 deletions
This file was deleted.
Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,86 @@
1+
# -*- coding: utf-8 -*-
2+
#
3+
# test_sinusoidal_generators_parallel_individual.py
4+
#
5+
# This file is part of NEST.
6+
#
7+
# Copyright (C) 2004 The NEST Initiative
8+
#
9+
# NEST is free software: you can redistribute it and/or modify
10+
# it under the terms of the GNU General Public License as published by
11+
# the Free Software Foundation, either version 2 of the License, or
12+
# (at your option) any later version.
13+
#
14+
# NEST is distributed in the hope that it will be useful,
15+
# but WITHOUT ANY WARRANTY; without even the implied warranty of
16+
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17+
# GNU General Public License for more details.
18+
#
19+
# You should have received a copy of the GNU General Public License
20+
# along with NEST. If not, see <http://www.gnu.org/licenses/>.
21+
22+
"""
23+
Test that of sinusoidal generators work correctly also in parallel.
24+
25+
This file tests the case that all targets receive individual spike trains.
26+
We cannot parametrize over the checker function, therefore we need two files.
27+
"""
28+
29+
import nest
30+
import numpy as np
31+
import pytest
32+
from mpi_test_wrapper import MPITestAssertEqual
33+
34+
35+
def assert_all_spike_trains_different(all_res):
36+
"""Assert that for each number of processes, all spike trains are pairwise different."""
37+
38+
for spikes in all_res["spike"]:
39+
ix_by_sender = list(spikes.groupby("sender").groups.values())
40+
assert all(
41+
not np.array_equal(spikes.iloc[lhs_ix].time_step, spikes.iloc[rhs_ix].time_step)
42+
for idx, lhs_ix in enumerate(ix_by_sender[:-1])
43+
for rhs_ix in ix_by_sender[(idx + 1) :]
44+
)
45+
46+
47+
@pytest.mark.skipif_incompatible_mpi
48+
@pytest.mark.skipif_missing_threads
49+
@pytest.mark.parametrize("gen_model", ["sinusoidal_poisson_generator", "sinusoidal_gamma_generator"])
50+
@pytest.mark.parametrize("num_threads", [1, 2])
51+
@MPITestAssertEqual([1, 2, 4], debug=False, specific_assert=assert_all_spike_trains_different)
52+
def test_sinusoidal_generator_with_spike_recorder(gen_model, num_threads):
53+
"""Test spike recording for individual spike trains.
54+
55+
The test builds a network with ``num_vp x 3`` parrot neurons that
56+
receives spikes from the specified sinusoidal generator. The test
57+
ensures that different targets receive different spike trains.
58+
"""
59+
60+
nest.total_num_virtual_procs = 4
61+
nrns_per_vp = 3
62+
total_num_nrns = nest.total_num_virtual_procs * nrns_per_vp
63+
64+
parrots = nest.Create("parrot_neuron", total_num_nrns)
65+
gen = nest.Create(
66+
gen_model,
67+
params={
68+
"rate": 100,
69+
"amplitude": 50.0,
70+
"frequency": 10.0,
71+
"individual_spike_trains": True,
72+
},
73+
)
74+
srec = nest.Create(
75+
"spike_recorder",
76+
params={
77+
"record_to": "ascii",
78+
"time_in_steps": True,
79+
"label": SPIKE_LABEL.format(nest.num_processes), # noqa: F821
80+
},
81+
)
82+
83+
nest.Connect(gen, parrots)
84+
nest.Connect(parrots, srec)
85+
86+
nest.Simulate(200.0)
Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,83 @@
1+
# -*- coding: utf-8 -*-
2+
#
3+
# test_sinusoidal_generators_parallel_same.py
4+
#
5+
# This file is part of NEST.
6+
#
7+
# Copyright (C) 2004 The NEST Initiative
8+
#
9+
# NEST is free software: you can redistribute it and/or modify
10+
# it under the terms of the GNU General Public License as published by
11+
# the Free Software Foundation, either version 2 of the License, or
12+
# (at your option) any later version.
13+
#
14+
# NEST is distributed in the hope that it will be useful,
15+
# but WITHOUT ANY WARRANTY; without even the implied warranty of
16+
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17+
# GNU General Public License for more details.
18+
#
19+
# You should have received a copy of the GNU General Public License
20+
# along with NEST. If not, see <http://www.gnu.org/licenses/>.
21+
22+
"""
23+
Test that of sinusoidal generators work correctly also in parallel.
24+
25+
This file tests the case that all targets receive the same spike train.
26+
We cannot parametrize over the checker function, therefore we need two files.
27+
"""
28+
29+
import nest
30+
import numpy as np
31+
import pytest
32+
from mpi_test_wrapper import MPITestAssertEqual
33+
34+
35+
def assert_all_spike_trains_equal(all_res):
36+
"""Assert that for each number of process, we have identical spike trains from all senders."""
37+
38+
for spikes in all_res["spike"]:
39+
ix_by_sender = list(spikes.groupby("sender").groups.values())
40+
ref = spikes.iloc[ix_by_sender[0]].time_step
41+
assert all(np.array_equal(ref, spikes.iloc[ix].time_step) for ix in ix_by_sender[1:])
42+
43+
44+
@pytest.mark.skipif_incompatible_mpi
45+
@pytest.mark.skipif_missing_threads
46+
@pytest.mark.parametrize("gen_model", ["sinusoidal_poisson_generator", "sinusoidal_gamma_generator"])
47+
@pytest.mark.parametrize("num_threads", [1, 2])
48+
@MPITestAssertEqual([1, 2, 4], debug=False, specific_assert=assert_all_spike_trains_equal)
49+
def test_sinusoidal_generator_with_spike_recorder(gen_model, num_threads):
50+
"""Test spike recording with ``individual_spike_trains == False``.
51+
52+
The test builds a network with ``num_vp x 3`` parrot neurons that
53+
receives spikes from the specified sinusoidal generator. The test
54+
ensures that different targets receive identical spike trains.
55+
"""
56+
57+
nest.total_num_virtual_procs = 4
58+
nrns_per_vp = 3
59+
total_num_nrns = nest.total_num_virtual_procs * nrns_per_vp
60+
61+
parrots = nest.Create("parrot_neuron", total_num_nrns)
62+
gen = nest.Create(
63+
gen_model,
64+
params={
65+
"rate": 100,
66+
"amplitude": 50.0,
67+
"frequency": 10.0,
68+
"individual_spike_trains": False,
69+
},
70+
)
71+
srec = nest.Create(
72+
"spike_recorder",
73+
params={
74+
"record_to": "ascii",
75+
"time_in_steps": True,
76+
"label": SPIKE_LABEL.format(nest.num_processes), # noqa: F821
77+
},
78+
)
79+
80+
nest.Connect(gen, parrots)
81+
nest.Connect(parrots, srec)
82+
83+
nest.Simulate(200.0)

0 commit comments

Comments
 (0)