Skip to content

Commit a42365f

Browse files
authored
Merge pull request #3647 from heplesser/test_symm_conns_mpi
Port test_symmetric_connections_mpi from SLI to Py
2 parents c4af547 + 7267719 commit a42365f

File tree

2 files changed

+65
-115
lines changed

2 files changed

+65
-115
lines changed

testsuite/mpitests/test_symmetric_connections_mpi.sli

Lines changed: 0 additions & 115 deletions
This file was deleted.
Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
# -*- coding: utf-8 -*-
2+
#
3+
# test_symmetric_connections_mpi.py
4+
#
5+
# This file is part of NEST.
6+
#
7+
# Copyright (C) 2004 The NEST Initiative
8+
#
9+
# NEST is free software: you can redistribute it and/or modify
10+
# it under the terms of the GNU General Public License as published by
11+
# the Free Software Foundation, either version 2 of the License, or
12+
# (at your option) any later version.
13+
#
14+
# NEST is distributed in the hope that it will be useful,
15+
# but WITHOUT ANY WARRANTY; without even the implied warranty of
16+
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17+
# GNU General Public License for more details.
18+
#
19+
# You should have received a copy of the GNU General Public License
20+
# along with NEST. If not, see <http://www.gnu.org/licenses/>.
21+
22+
import numpy as np
23+
import pandas as pd
24+
import pytest
25+
from mpi_test_wrapper import MPITestAssertEqual
26+
27+
28+
def assert_symmetric(all_res):
29+
for conns in all_res["other"]:
30+
conns.set_index(["source", "target"], inplace=True)
31+
assert all(all(conns.loc[(s, t)] == conns.loc[(t, s)]) for (s, t) in conns.index)
32+
33+
34+
@pytest.mark.skipif_incompatible_mpi
35+
@pytest.mark.skipif_missing_threads
36+
@MPITestAssertEqual([1, 2, 4], debug=False, specific_assert=assert_symmetric)
37+
def test_symmetric_connections_mpi():
38+
"""
39+
Confirm that symmetric connections are created correctly.
40+
"""
41+
42+
import nest
43+
import pandas as pd # noqa: F811
44+
45+
nest.total_num_virtual_procs = 4
46+
47+
N = 5
48+
pop1 = nest.Create("parrot_neuron", 5)
49+
pop2 = nest.Create("parrot_neuron", 5)
50+
51+
nest.Connect(
52+
pop1,
53+
pop2,
54+
{"rule": "one_to_one", "make_symmetric": True},
55+
{
56+
"synapse_model": "stdp_synapse",
57+
"weight": np.linspace(1, 5, num=N),
58+
"delay": np.linspace(11, 15, num=N),
59+
"alpha": np.linspace(21, 25, num=N),
60+
},
61+
)
62+
63+
conns = pd.DataFrame.from_dict(nest.GetConnections().get(["source", "target", "weight", "delay", "alpha"]))
64+
65+
conns.to_csv(OTHER_LABEL.format(nest.num_processes, nest.Rank()), index=False, sep="\t") # noqa: F821

0 commit comments

Comments
 (0)