Skip to content

Commit dfc4d3e

Browse files
authored
Merge pull request #3536 from med-ayssar/3532
Fix issue: #3532: Disconnect issue with CompressedSpikes:ON/Off
2 parents 7ab0c23 + e8c610d commit dfc4d3e

File tree

10 files changed

+248
-82
lines changed

10 files changed

+248
-82
lines changed

libnestutil/nest_types.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,7 @@ constexpr uint8_t NUM_BITS_PROCESSED_FLAG = 1U;
9292
constexpr uint8_t NUM_BITS_MARKER_SPIKE_DATA = 2U;
9393
constexpr uint8_t NUM_BITS_LAG = 14U;
9494
constexpr uint8_t NUM_BITS_DELAY = 21U;
95-
constexpr uint8_t NUM_BITS_NODE_ID = 62U;
95+
constexpr uint8_t NUM_BITS_NODE_ID = 61U;
9696

9797
// Maximally allowed values for bitfields
9898

models/eprop_synapse.cpp

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -108,7 +108,6 @@ Connector< eprop_synapse< TargetIdentifierPtrRport > >::disable_connection( cons
108108
{
109109
assert( not C_[ lcid ].is_disabled() );
110110
C_[ lcid ].disable();
111-
C_[ lcid ].delete_optimizer();
112111
}
113112

114113
template <>
@@ -117,7 +116,6 @@ Connector< eprop_synapse< TargetIdentifierIndex > >::disable_connection( const s
117116
{
118117
assert( not C_[ lcid ].is_disabled() );
119118
C_[ lcid ].disable();
120-
C_[ lcid ].delete_optimizer();
121119
}
122120

123121

@@ -136,7 +134,10 @@ Connector< eprop_synapse< TargetIdentifierIndex > >::~Connector()
136134
{
137135
for ( auto& c : C_ )
138136
{
139-
c.delete_optimizer();
137+
if ( not c.is_disabled() )
138+
{
139+
c.delete_optimizer();
140+
}
140141
}
141142
C_.clear();
142143
}

models/eprop_synapse_bsshslm_2020.cpp

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -113,7 +113,6 @@ Connector< eprop_synapse_bsshslm_2020< TargetIdentifierPtrRport > >::disable_con
113113
{
114114
assert( not C_[ lcid ].is_disabled() );
115115
C_[ lcid ].disable();
116-
C_[ lcid ].delete_optimizer();
117116
}
118117

119118
template <>
@@ -122,7 +121,6 @@ Connector< eprop_synapse_bsshslm_2020< TargetIdentifierIndex > >::disable_connec
122121
{
123122
assert( not C_[ lcid ].is_disabled() );
124123
C_[ lcid ].disable();
125-
C_[ lcid ].delete_optimizer();
126124
}
127125

128126

nestkernel/connection_manager.cpp

Lines changed: 21 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,7 @@
5555
#include "kernel_manager.h"
5656
#include "mpi_manager_impl.h"
5757
#include "nest_names.h"
58+
#include "nest_types.h"
5859
#include "node.h"
5960
#include "sonata_connector.h"
6061
#include "stopwatch_impl.h"
@@ -976,23 +977,25 @@ nest::ConnectionManager::find_connection( const size_t tid,
976977
const size_t snode_id,
977978
const size_t tnode_id )
978979
{
979-
// lcid will hold the position of the /first/ connection from node
980-
// snode_id to any local node, or be invalid
981-
size_t lcid = source_table_.find_first_source( tid, syn_id, snode_id );
982-
if ( lcid == invalid_index )
980+
if ( use_compressed_spikes_ )
983981
{
984-
return invalid_index;
985-
}
982+
const size_t source_index = source_table_.find_first_source( tid, syn_id, snode_id );
983+
if ( source_index == invalid_index )
984+
{
985+
return invalid_index;
986+
}
987+
988+
// lcid will hold the position of the /first/ enabled connection from node
989+
// snode_id to node tnode_id, or be invalid
990+
const size_t lcid = connections_[ tid ][ syn_id ]->find_first_target( tid, source_index, tnode_id );
986991

987-
// lcid will hold the position of the /first/ connection from node
988-
// snode_id to node tnode_id, or be invalid
989-
lcid = connections_[ tid ][ syn_id ]->find_first_target( tid, lcid, tnode_id );
990-
if ( lcid != invalid_index )
991-
{
992992
return lcid;
993993
}
994-
995-
return lcid;
994+
else
995+
{
996+
return connections_[ tid ][ syn_id ]->find_enabled_connection( tid, syn_id, snode_id, tnode_id, source_table_ );
997+
}
998+
return invalid_index;
996999
}
9971000

9981001
void
@@ -1003,7 +1006,7 @@ nest::ConnectionManager::disconnect( const size_t tid,
10031006
{
10041007
assert( syn_id != invalid_synindex );
10051008

1006-
const size_t lcid = find_connection( tid, syn_id, snode_id, tnode_id );
1009+
const auto lcid = find_connection( tid, syn_id, snode_id, tnode_id );
10071010

10081011
if ( lcid == invalid_index ) // this function should only be called
10091012
// with a valid connection
@@ -1450,7 +1453,7 @@ nest::ConnectionManager::sort_connections( const size_t tid )
14501453
connections_[ tid ][ syn_id ]->sort_connections( source_table_.get_thread_local_sources( tid )[ syn_id ] );
14511454
}
14521455
}
1453-
remove_disabled_connections( tid );
1456+
remove_disabled_connections_( tid );
14541457
}
14551458
}
14561459

@@ -1700,8 +1703,10 @@ nest::ConnectionManager::compress_secondary_send_buffer_pos( const size_t tid )
17001703
}
17011704

17021705
void
1703-
nest::ConnectionManager::remove_disabled_connections( const size_t tid )
1706+
nest::ConnectionManager::remove_disabled_connections_( const size_t tid )
17041707
{
1708+
assert( use_compressed_spikes_ );
1709+
17051710
std::vector< ConnectorBase* >& connectors = connections_[ tid ];
17061711

17071712
for ( synindex syn_id = 0; syn_id < connectors.size(); ++syn_id )

nestkernel/connection_manager.h

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -201,6 +201,11 @@ class ConnectionManager : public ManagerInterface
201201
const DictionaryDatum& third_connectivity,
202202
const std::map< Name, std::vector< DictionaryDatum > >& synapse_specs );
203203

204+
/**
205+
* Find first non-disabled thread-local connection of given synapse type with given source and target node.
206+
*
207+
* @returns Local connection id (lcid) or `invalid_index`
208+
*/
204209
size_t find_connection( const size_t tid, const synindex syn_id, const size_t snode_id, const size_t tnode_id );
205210

206211
void disconnect( const size_t tid, const synindex syn_id, const size_t snode_id, const size_t tnode_id );
@@ -388,11 +393,6 @@ class ConnectionManager : public ManagerInterface
388393
*/
389394
void sort_connections( const size_t tid );
390395

391-
/**
392-
* Removes disabled connections (of structural plasticity)
393-
*/
394-
void remove_disabled_connections( const size_t tid );
395-
396396
/**
397397
* Returns true if connection information needs to be
398398
* communicated. False otherwise.
@@ -496,6 +496,11 @@ class ConnectionManager : public ManagerInterface
496496
const size_t tnode_id,
497497
std::vector< size_t >& sources );
498498

499+
/**
500+
* Removes disabled connections (of structural plasticity)
501+
*/
502+
void remove_disabled_connections_( const size_t tid );
503+
499504
/**
500505
* Splits a TokenArray of node IDs to two vectors containing node IDs of neurons and
501506
* node IDs of devices.

nestkernel/connector_base.h

Lines changed: 23 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@
4444
#include "nest_names.h"
4545
#include "node.h"
4646
#include "source.h"
47+
#include "source_table.h"
4748
#include "spikecounter.h"
4849

4950
// Includes from sli:
@@ -192,14 +193,16 @@ class ConnectorBase
192193
virtual size_t find_first_target( const size_t tid, const size_t start_lcid, const size_t target_node_id ) const = 0;
193194

194195
/**
195-
* Return lcid of first connection where the node ID of the target
196-
* matches target_node_id; consider only the connections with lcids
197-
* given in matching_lcids. If there is no match, the function returns
198-
* invalid_index.
196+
* Return lcid of first connection matching source and target node id and that
197+
* is not disabled.
198+
*
199+
* Intended for use with unsorted (uncompressed) connections.
199200
*/
200-
virtual size_t find_matching_target( const size_t tid,
201-
const std::vector< size_t >& matching_lcids,
202-
const size_t target_node_id ) const = 0;
201+
virtual size_t find_enabled_connection( const size_t tid,
202+
const size_t syn_id,
203+
const size_t source_node_id,
204+
const size_t target_node_id,
205+
const SourceTable& source_table ) const = 0;
203206

204207
/**
205208
* Disable the transfer of events through the connection at position
@@ -466,6 +469,10 @@ class Connector : public ConnectorBase
466469
size_t
467470
find_first_target( const size_t tid, const size_t start_lcid, const size_t target_node_id ) const override
468471
{
472+
// TODO: Once #3544 is merged, activate this assertion. It is currently
473+
// commented out to avoid circular inclusions.
474+
// assert( kernel().connection_manager.use_compressed_spikes() );
475+
469476
size_t lcid = start_lcid;
470477
while ( true )
471478
{
@@ -484,15 +491,18 @@ class Connector : public ConnectorBase
484491
}
485492

486493
size_t
487-
find_matching_target( const size_t tid,
488-
const std::vector< size_t >& matching_lcids,
489-
const size_t target_node_id ) const override
494+
find_enabled_connection( const size_t tid,
495+
const size_t syn_id,
496+
const size_t source_node_id,
497+
const size_t target_node_id,
498+
const SourceTable& source_table ) const override
490499
{
491-
for ( size_t i = 0; i < matching_lcids.size(); ++i )
500+
for ( size_t lcid = 0; lcid < C_.size(); ++lcid )
492501
{
493-
if ( C_[ matching_lcids[ i ] ].get_target( tid )->get_node_id() == target_node_id )
502+
if ( source_table.get_node_id( tid, syn_id, lcid ) == source_node_id
503+
and C_[ lcid ].get_target( tid )->get_node_id() == target_node_id and not C_[ lcid ].is_disabled() )
494504
{
495-
return matching_lcids[ i ];
505+
return lcid;
496506
}
497507
}
498508

nestkernel/source.h

Lines changed: 6 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -43,17 +43,13 @@ class Source
4343
uint64_t node_id_ : NUM_BITS_NODE_ID; //!< node ID of source
4444
bool processed_ : 1; //!< whether this target has already been moved
4545
//!< to the MPI buffer
46-
bool primary_ : 1;
46+
bool primary_ : 1; //!< source of primary connection
47+
bool disabled_ : 1; //!< connection has been disabled
4748

4849
public:
4950
Source();
5051
explicit Source( const uint64_t node_id, const bool primary );
5152

52-
/**
53-
* Sets node_id_ to the specified value.
54-
*/
55-
void set_node_id( const uint64_t node_id );
56-
5753
/**
5854
* Returns this Source's node ID.
5955
*/
@@ -91,24 +87,19 @@ inline Source::Source()
9187
: node_id_( 0 )
9288
, processed_( false )
9389
, primary_( true )
90+
, disabled_( false )
9491
{
9592
}
9693

9794
inline Source::Source( const uint64_t node_id, const bool is_primary )
9895
: node_id_( node_id )
9996
, processed_( false )
10097
, primary_( is_primary )
98+
, disabled_( false )
10199
{
102100
assert( node_id <= MAX_NODE_ID );
103101
}
104102

105-
inline void
106-
Source::set_node_id( const uint64_t node_id )
107-
{
108-
assert( node_id <= MAX_NODE_ID );
109-
node_id_ = node_id;
110-
}
111-
112103
inline uint64_t
113104
Source::get_node_id() const
114105
{
@@ -142,13 +133,13 @@ Source::is_primary() const
142133
inline void
143134
Source::disable()
144135
{
145-
node_id_ = DISABLED_NODE_ID;
136+
disabled_ = true;
146137
}
147138

148139
inline bool
149140
Source::is_disabled() const
150141
{
151-
return node_id_ == DISABLED_NODE_ID;
142+
return disabled_;
152143
}
153144

154145
inline bool

nestkernel/source_table.cpp

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -162,9 +162,41 @@ nest::SourceTable::get_node_id( const size_t tid, const synindex syn_id, const s
162162
return sources_[ tid ][ syn_id ][ lcid ].get_node_id();
163163
}
164164

165+
size_t
166+
nest::SourceTable::find_first_source( const size_t tid, const synindex syn_id, const size_t snode_id ) const
167+
{
168+
const auto source_begin = sources_[ tid ][ syn_id ].begin();
169+
const auto source_end = sources_[ tid ][ syn_id ].end();
170+
171+
auto first_source_match = source_begin;
172+
if ( kernel().connection_manager.use_compressed_spikes() )
173+
{
174+
// Binary search for first entry matching snode_id; is_primary is ignored
175+
const Source requested_source { snode_id, /* is_primary */ true };
176+
first_source_match = std::lower_bound( source_begin, source_end, requested_source );
177+
}
178+
179+
// Linear search for first non-disabled connection
180+
const auto first_enabled = std::find_if( first_source_match,
181+
source_end,
182+
[ &snode_id ]( const Source& src ) { return src.get_node_id() == snode_id and not src.is_disabled(); } );
183+
if ( first_enabled != source_end )
184+
{
185+
// lcid is iterator difference
186+
return first_enabled - source_begin;
187+
}
188+
else
189+
{
190+
// no enabled entry with this snode ID found
191+
return invalid_index;
192+
}
193+
}
194+
165195
size_t
166196
nest::SourceTable::remove_disabled_sources( const size_t tid, const synindex syn_id )
167197
{
198+
assert( kernel().connection_manager.use_compressed_spikes() );
199+
168200
if ( sources_[ tid ].size() <= syn_id )
169201
{
170202
return invalid_index; // no source table entry for this synapse model

nestkernel/source_table.h

Lines changed: 6 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -137,7 +137,6 @@ class SourceTable
137137
*/
138138
static const size_t min_deleted_elements_ = 1000000;
139139

140-
141140
/**
142141
* Returns whether this Source object should be considered when
143142
* constructing MPI buffers for communicating connections.
@@ -297,8 +296,12 @@ class SourceTable
297296
std::map< size_t, size_t >& buffer_pos_of_source_node_id_syn_id_ );
298297

299298
/**
300-
* Finds the first entry in sources_ at the given thread id and
301-
* synapse type that is equal to snode_id.
299+
* Finds the first non-disabled entry in sources_ at the given thread id and synapse type that has sender equal to
300+
* snode_id.
301+
*
302+
* @returns local connection id (lcid) or `invalid_index`
303+
*
304+
* @note For compressed spikes, it uses binary search, otherwise linear search.
302305
*/
303306
size_t find_first_source( const size_t tid, const synindex syn_id, const size_t snode_id ) const;
304307

@@ -470,30 +473,6 @@ SourceTable::no_targets_to_process( const size_t tid )
470473
current_positions_[ tid ].lcid = -1;
471474
}
472475

473-
inline size_t
474-
SourceTable::find_first_source( const size_t tid, const synindex syn_id, const size_t snode_id ) const
475-
{
476-
// binary search in sorted sources
477-
const BlockVector< Source >::const_iterator begin = sources_[ tid ][ syn_id ].begin();
478-
const BlockVector< Source >::const_iterator end = sources_[ tid ][ syn_id ].end();
479-
BlockVector< Source >::const_iterator it = std::lower_bound( begin, end, Source( snode_id, true ) );
480-
481-
// source found by binary search could be disabled, iterate through
482-
// sources until a valid one is found
483-
while ( it != end )
484-
{
485-
if ( it->get_node_id() == snode_id and not it->is_disabled() )
486-
{
487-
const size_t lcid = it - begin;
488-
return lcid;
489-
}
490-
++it;
491-
}
492-
493-
// no enabled entry with this snode ID found
494-
return invalid_index;
495-
}
496-
497476
inline void
498477
SourceTable::disable_connection( const size_t tid, const synindex syn_id, const size_t lcid )
499478
{

0 commit comments

Comments
 (0)