@@ -49,15 +49,20 @@ void Context::createAndConnectAllPairs(IStore& store) {
4949 int localRank = 0 ;
5050 bool localRankSet = false ;
5151 auto localHostName = getHostname ();
52+ bool useRankAsSeqNum = useRankAsSeqNumber ();
5253
5354 // We will create all the pairs including self
5455 // the self pair will not be connected
5556 // it's just to keep the later seq num matching logic simple
5657 std::vector<ssize_t > pairIdentifiers;
5758 for (int i = 0 ; i < size; i++) {
58- auto & pair = createPair (i);
59- pairIdentifiers.emplace_back (
60- static_cast <Pair*>(pair.get ())->address ().getSeq ());
59+ const auto & pair = createPair (i, useRankAsSeqNum);
60+ if (!useRankAsSeqNum) {
61+ // Need to preserve the order of the pair identifiers if we are not using
62+ // the rank as seq number
63+ pairIdentifiers.emplace_back (
64+ static_cast <Pair*>(pair.get ())->address ().getSeq ());
65+ }
6166 }
6267
6368 // Obtain the pair object for this rank
@@ -105,8 +110,9 @@ void Context::createAndConnectAllPairs(IStore& store) {
105110
106111 const auto & pair = getPair (i);
107112 auto remoteDeviceAddr = Address (remoteRankInfo.addressBytes ).getSockaddr ();
108- auto remoteAddr =
109- Address (remoteDeviceAddr, remoteRankInfo.pairIdentifiers [rank]);
113+ auto remoteAddr = Address (
114+ remoteDeviceAddr,
115+ useRankAsSeqNum ? (ssize_t )rank : remoteRankInfo.pairIdentifiers [rank]);
110116 pair->connect (remoteAddr.bytes ());
111117 }
112118
@@ -124,7 +130,15 @@ void Context::createAndConnectAllPairs(IStore& store) {
124130
125131std::unique_ptr<transport::Pair>& Context::createPair (int rank) {
126132 pairs_[rank] = std::unique_ptr<transport::Pair>(
127- new tcp::Pair (this , device_.get (), rank, getTimeout ()));
133+ new tcp::Pair (this , device_.get (), rank, getTimeout (), false ));
134+ return pairs_[rank];
135+ }
136+
137+ std::unique_ptr<transport::Pair>& Context::createPair (
138+ int rank,
139+ bool useRankAsSeqNumber = false ) {
140+ pairs_[rank] = std::unique_ptr<transport::Pair>(new tcp::Pair (
141+ this , device_.get (), rank, getTimeout (), useRankAsSeqNumber));
128142 return pairs_[rank];
129143}
130144
@@ -305,14 +319,16 @@ Rank::Rank(const std::vector<char>& bytes) {
305319 bytesOffset += sizeof (addrSz) + addrSz;
306320 // pair identifiers
307321 size_t pairIdChunkSz = bytes.size () - bytesOffset;
308- GLOO_ENFORCE_EQ (
309- pairIdChunkSz % sizeof (ssize_t ),
310- 0 ,
311- " Remaining bytes do not map to entire chunk of pair identifiers" );
312- size_t numPairs = pairIdChunkSz / sizeof (ssize_t );
313- pairIdentifiers.resize (numPairs);
314- std::memcpy (
315- pairIdentifiers.data (), bytes.data () + bytesOffset, pairIdChunkSz);
322+ if (pairIdChunkSz) {
323+ GLOO_ENFORCE_EQ (
324+ pairIdChunkSz % sizeof (ssize_t ),
325+ 0 ,
326+ " Remaining bytes do not map to entire chunk of pair identifiers" );
327+ size_t numPairs = pairIdChunkSz / sizeof (ssize_t );
328+ pairIdentifiers.resize (numPairs);
329+ std::memcpy (
330+ pairIdentifiers.data (), bytes.data () + bytesOffset, pairIdChunkSz);
331+ }
316332}
317333
318334std::vector<char > Rank::bytes () const {
@@ -336,7 +352,9 @@ std::vector<char> Rank::bytes() const {
336352 std::memcpy (bufOffset, addressBytes.data (), addressBytes.size ());
337353 bufOffset += addrSz;
338354 // pair identifiers
339- std::memcpy (bufOffset, pairIdentifiers.data (), pairIdChunkSz);
355+ if (pairIdChunkSz) {
356+ std::memcpy (bufOffset, pairIdentifiers.data (), pairIdChunkSz);
357+ }
340358 return buf;
341359}
342360
0 commit comments