Commit 7c66db5

Merge pull request nest#3643 from heplesser/test_rate_neurons_mpi
Port test_rate_neurons_mpi from SLI to Py (and improve MPITestWrapper)
2 parents: 3fed3bd + 71a1931, commit 7c66db5

8 files changed: +132 additions, -103 deletions


testsuite/mpitests/test_rate_neurons_mpi.sli

Lines changed: 0 additions & 91 deletions
This file was deleted.

testsuite/pytests/sli2py_mpi/mpi_test_wrapper.py

Lines changed: 38 additions & 7 deletions
@@ -49,8 +49,8 @@
 - `MULTI_LABEL`
 - `OTHER_LABEL`
 They must be used as `label` for spike recorders and multimeters, respectively,
-or for other files for output data (CSV files). They are format strings expecting
-the number of processes with which NEST is run as argument.
+or for other files for output data (TAB-separated CSV files). They are format
+strings expecting the number of processes with which NEST is run as argument.
 - Set `debug=True` on the decorator to see debug output and keep the
   temporary directory that has been created (latter works only in
   Python 3.12 and later)
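
As a point of reference, the changed tests further down in this commit fill these format strings in like this (quoted from test_all_to_all.py and the new test_rate_neurons_mpi.py; the label names are provided to the test at run time, hence the noqa: F821 markers):

    # writing rank-local data to an OTHER_LABEL file (test_all_to_all.py):
    conns.to_csv(OTHER_LABEL.format(nest.num_processes, nest.Rank()), index=False, sep="\t")  # noqa: F821

    # using MULTI_LABEL as a multimeter label (test_rate_neurons_mpi.py):
    "label": MULTI_LABEL.format(nest.num_processes),  # noqa: F821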
@@ -239,10 +239,13 @@ def _collect_result_by_label(self, tmpdirpath, label):
         label += "-{}.dat"

         try:
-            next(tmpdirpath.glob(label.format("*", "*")))
+            first_file = next(tmpdirpath.glob(label.format("*", "*")))
         except StopIteration:
             return None  # no data for this label

+        # Confirm we have tab-separated data. Assumes that all data have at least two columns.
+        assert "\t" in open(first_file).read(), "All data files must be tab-separated"
+
         res = {}
         for n_procs in self._procs_lst:
             data = []
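
For context, a minimal sketch of reading one such per-rank file back with pandas; this is not part of the commit, and the handling of possible '#'-prefixed metadata lines is an assumption rather than something shown in this hunk:

    import pandas as pd

    # first_file is the Path found by the glob above; sep must match the TAB-separated
    # convention that the assert enforces. comment="#" skips any '#'-prefixed metadata
    # lines the recorders may write (assumption).
    df = pd.read_csv(first_file, sep="\t", comment="#")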
@@ -255,6 +258,18 @@ def _collect_result_by_label(self, tmpdirpath, label):

         return res

+    @staticmethod
+    def _drop_empty_dataframes(data):
+        """
+        Return list of non-empty dataframes in data.
+
+        The data frames collected for a given number of processes may contain empty
+        dataframes. pandas.concat() will not support them any more in the future, so
+        we filter them out for tests that use concat().
+        """
+
+        return [df for df in data if not df.empty]
+
     def collect_results(self, tmpdirpath):
         """
         For each of the result types, build a dictionary mapping number of MPI procs to a list of
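
A standalone pandas sketch (made-up column names, not part of the wrapper) of the problem this helper works around: concatenating a list that contains empty frames is deprecated in pandas, so the empty ones are dropped first.

    import pandas as pd

    frames = [pd.DataFrame({"sender": [1], "rate": [20.0]}), pd.DataFrame()]

    # pd.concat(frames) would warn about empty entries and will stop supporting them;
    # filtering as _drop_empty_dataframes does keeps only the frames with data.
    combined = pd.concat([df for df in frames if not df.empty], ignore_index=True)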
@@ -280,36 +295,52 @@ def assert_correct_results(self, tmpdirpath):
         all_res = []
         if self._spike:
             # For each number of procs, combine results across VPs and sort by time and sender
+
+            # Include only frames containing at least one non-nan value so pandas knows datatypes.
+            # .all() returns True for empty arrays.
             all_res.append(
                 [
-                    pd.concat(spikes, ignore_index=True).sort_values(
+                    pd.concat(self._drop_empty_dataframes(spikes), ignore_index=True).sort_values(
                         by=["time_step", "time_offset", "sender"], ignore_index=True
                     )
                     for spikes in self._spike.values()
                 ]
             )

         if self._multi:
-            raise NotImplementedError("MULTI is not ready yet")
+            # For each number of procs, combine results across VPs and sort by time and sender
+            # Include only frames containing at least one non-nan value so pandas knows datatypes.
+            # .all() returns True for empty arrays.
+            all_res.append(
+                [
+                    pd.concat(self._drop_empty_dataframes(mmdata), ignore_index=True).sort_values(
+                        by=["time_step", "time_offset", "sender"], ignore_index=True
+                    )
+                    for mmdata in self._multi.values()
+                ]
+            )

         if self._other:
             # For each number of procs, combine across ranks or VPs (depends on what test has written) and
             # sort by all columns so that if results for different proc numbers are equal up to a permutation
             # of rows, the sorted frames will compare equal
+            # Include only frames containing at least one non-nan value so pandas knows datatypes.
+            # .all() returns True for empty arrays.

             # next(iter(...)) returns the first value in the _other dictionary
             # [0] then picks the first DataFrame from that list
             # columns need to be converted to list() to be passed to sort_values()
             all_columns = list(next(iter(self._other.values()))[0].columns)
             all_res.append(
                 [
-                    pd.concat(others, ignore_index=True).sort_values(by=all_columns, ignore_index=True)
+                    pd.concat(self._drop_empty_dataframes(others), ignore_index=True).sort_values(
+                        by=all_columns, ignore_index=True
+                    )
                     for others in self._other.values()
                 ]
             )

         assert all_res, "No test data collected"
-
         for res in all_res:
             assert len(res) == len(self._procs_lst), "Could not collect data for all procs"

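The sort-by-all-columns comparison used for the `_other` results can be illustrated in isolation (made-up data, not part of the wrapper): two frames whose rows are permutations of each other become identical after sorting.

    import pandas as pd

    a = pd.DataFrame({"source": [1, 2], "target": [3, 4]})
    b = pd.DataFrame({"source": [2, 1], "target": [4, 3]})  # same rows, different order

    cols = list(a.columns)  # sort_values() needs a list of column names
    assert a.sort_values(by=cols, ignore_index=True).equals(b.sort_values(by=cols, ignore_index=True))
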
testsuite/pytests/sli2py_mpi/test_all_to_all.py

Lines changed: 1 addition & 1 deletion
@@ -40,4 +40,4 @@ def test_all_to_all(N):
     nest.Connect(nrns, nrns, "all_to_all")

     conns = nest.GetConnections().get(output="pandas").drop(labels=["target_thread", "port"], axis=1)
-    conns.to_csv(OTHER_LABEL.format(nest.num_processes, nest.Rank()), index=False)  # noqa: F821
+    conns.to_csv(OTHER_LABEL.format(nest.num_processes, nest.Rank()), index=False, sep="\t")  # noqa: F821

testsuite/pytests/sli2py_mpi/test_connect_array_mpi.py

Lines changed: 1 addition & 1 deletion
@@ -109,4 +109,4 @@ def gids_to_delays(sgids, tgids):
     assert set(actual_weights) <= set(expected_weights)
     assert set(actual_delays) <= set(expected_delays)

-    conns.to_csv(OTHER_LABEL.format(nest.num_processes, nest.Rank()), index=False)  # noqa: F821
+    conns.to_csv(OTHER_LABEL.format(nest.num_processes, nest.Rank()), index=False, sep="\t")  # noqa: F821

testsuite/pytests/sli2py_mpi/test_issue_1957.py

Lines changed: 1 addition & 1 deletion
@@ -47,7 +47,7 @@ def test_issue_1957():
     if pre_conns:
         # need to do this here, Disconnect invalidates pre_conns
         df = pd.DataFrame.from_dict(pre_conns.get()).drop(labels="target_thread", axis=1)
-        df.to_csv(OTHER_LABEL.format(nest.num_processes, nest.Rank()), index=False)  # noqa: F821
+        df.to_csv(OTHER_LABEL.format(nest.num_processes, nest.Rank()), index=False, sep="\t")  # noqa: F821

         nest.Disconnect(nrn, nrn)
         nest.Disconnect(nrn, nrn)

testsuite/pytests/sli2py_mpi/test_issue_2119.py

Lines changed: 1 addition & 1 deletion
@@ -51,5 +51,5 @@ def test_issue_2119(kind, specs):
     nrn = nest.Create("iaf_psc_alpha", n=4, params={"V_m": nest.CreateParameter(kind, specs)})

     pd.DataFrame.from_dict(nrn.get(["global_id", "V_m"])).dropna().to_csv(
-        OTHER_LABEL.format(nest.num_processes, nest.Rank()), index=False  # noqa: F821
+        OTHER_LABEL.format(nest.num_processes, nest.Rank()), index=False, sep="\t"  # noqa: F821
     )

testsuite/pytests/sli2py_mpi/test_rate_neurons_mpi.py

Lines changed: 89 additions & 0 deletions
@@ -0,0 +1,89 @@
# -*- coding: utf-8 -*-
#
# test_rate_neurons_mpi.py
#
# This file is part of NEST.
#
# Copyright (C) 2004 The NEST Initiative
#
# NEST is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 2 of the License, or
# (at your option) any later version.
#
# NEST is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with NEST. If not, see <http://www.gnu.org/licenses/>.

import pytest
from mpi_test_wrapper import MPITestAssertEqual


@pytest.mark.skipif_incompatible_mpi
@pytest.mark.skipif_missing_threads
@pytest.mark.skipif_missing_gsl
@pytest.mark.parametrize(
    "neuron_model, params1, params2, synspec",
    [
        [
            "lin_rate_ipn",
            {"mu": 0, "sigma": 0, "rate": 20},
            {"mu": 0, "sigma": 0},
            {"synapse_model": "rate_connection_instantaneous", "weight": 5},
        ],
        [
            "siegert_neuron",
            {"rate": 20},
            {},
            {"synapse_model": "diffusion_connection", "diffusion_factor": 2, "drift_factor": 4},
        ],
    ],
)
@MPITestAssertEqual([1, 2], debug=False)
def test_rate_neurons_mpi(neuron_model, params1, params2, synspec):
    """
    Test that rate neurons are simulated correctly in parallel.

    The test is performed on the multimeter data recorded to MULTI_LABEL during the simulation.
    """

    import nest

    total_vps = 4
    h = 0.1

    nest.SetKernelStatus(
        {
            "total_num_virtual_procs": total_vps,
            "resolution": h,
            "use_wfr": True,
            "wfr_tol": 0.0001,
            "wfr_interpolation_order": 3,
            "wfr_max_iterations": 10,
            "wfr_comm_interval": 1.0,
        }
    )

    neuron1 = nest.Create(neuron_model, params=params1)
    neuron2 = nest.Create(neuron_model, params=params2)
    mm = nest.Create(
        "multimeter",
        params={
            "record_from": ["rate"],
            "interval": 1,
            "record_to": "ascii",
            "precision": 8,
            "time_in_steps": True,
            "label": MULTI_LABEL.format(nest.num_processes),  # noqa: F821
        },
    )

    nest.Connect(mm, neuron1 + neuron2)

    nest.Connect(neuron1, neuron2, syn_spec=synspec)

    nest.Simulate(11)
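
For orientation, a condensed, illustrative stand-in (made-up numbers, hypothetical variable names, not the wrapper's actual code) for the check that MPITestAssertEqual ultimately performs on the multimeter data recorded here: per-VP frames are combined and sorted for each process count, and the results for 1 and 2 processes must agree.

    import pandas as pd

    # Stand-in for what the wrapper collects: number of MPI processes -> list of per-VP frames.
    per_procs = {
        1: [pd.DataFrame({"sender": [1, 2], "time_step": [10, 10], "time_offset": [0.0, 0.0], "rate": [20.0, 5.0]})],
        2: [
            pd.DataFrame({"sender": [2], "time_step": [10], "time_offset": [0.0], "rate": [5.0]}),
            pd.DataFrame({"sender": [1], "time_step": [10], "time_offset": [0.0], "rate": [20.0]}),
        ],
    }

    combined = {
        n: pd.concat([df for df in frames if not df.empty], ignore_index=True).sort_values(
            by=["time_step", "time_offset", "sender"], ignore_index=True
        )
        for n, frames in per_procs.items()
    }

    # The test passes if the combined, sorted frames agree for every number of processes.
    pd.testing.assert_frame_equal(combined[1], combined[2])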

testsuite/pytests/sli2py_mpi/test_self_get_conns_with_empty_ranks.py

Lines changed: 1 addition & 1 deletion
@@ -40,4 +40,4 @@ def test_get_conns_with_empty_ranks():
     nest.Connect(nrns, nrns)

     conns = nest.GetConnections().get(output="pandas").drop(labels=["target_thread", "port"], axis=1, errors="ignore")
-    conns.to_csv(OTHER_LABEL.format(nest.num_processes, nest.Rank()), index=False)  # noqa: F821
+    conns.to_csv(OTHER_LABEL.format(nest.num_processes, nest.Rank()), index=False, sep="\t")  # noqa: F821
