88
99#include " gloo/transport/tcp/context.h"
1010
11+ #include < algorithm>
12+ #include < cstdint>
1113#include < cstring>
1214#include < iostream>
15+ #include < string>
1316
14- #include " gloo/common/error.h"
1517#include " gloo/common/logging.h"
1618#include " gloo/common/utils.h"
1719#include " gloo/transport/tcp/device.h"
@@ -22,6 +24,8 @@ namespace gloo {
2224namespace transport {
2325namespace tcp {
2426
// Upper bound on the number of keys requested from the store in a single
// multi_get call when remote rank info is fetched in batches.
27+ constexpr int kDefaultBatchSize = 128 ;
28+
// Constructs the TCP context: `rank` and `size` are forwarded to the base
// ::gloo::transport::Context, and `device` is moved into the device_ member.
// The constructor body itself is empty.
2529Context::Context (std::shared_ptr<Device> device, int rank, int size)
2630 : ::gloo::transport::Context(rank, size), device_(std::move(device)) {}
2731
@@ -78,12 +82,36 @@ void Context::createAndConnectAllPairs(IStore& store) {
7882 // which does not have the rank info hosted at a higher `Pair` level).
7983 // So, better safe than sorry: for now we try to minimize the changeset needed.
8084 const auto & currentRankPair = getPair (rank);
81- auto deviceAddress = Address (
85+ const auto & deviceAddress = Address (
8286 static_cast <const Pair*>(currentRankPair.get ())->address ().getSockaddr ());
8387 Rank currentRankInfo (
8488 localHostName, deviceAddress.bytes (), std::move (pairIdentifiers));
8589 store.set (std::to_string (rank), currentRankInfo.bytes ());
8690
91+ std::vector<std::vector<char >> remoteRankInfos;
92+ int key = 0 ;
93+ if (isStoreExtendedApiEnabled () && store.has_v2_support ()) {
94+ auto sizeRemaining = size;
95+ while (sizeRemaining > 0 ) {
96+ const auto batchKeys = std::min (kDefaultBatchSize , sizeRemaining);
97+ std::vector<std::string> keys (batchKeys);
98+ std::generate_n (
99+ keys.begin (), batchKeys, [&] { return std::to_string (key++); });
100+ const auto & batchRemoteInfos = store.multi_get (keys);
101+ remoteRankInfos.insert (
102+ remoteRankInfos.end (),
103+ batchRemoteInfos.begin (),
104+ batchRemoteInfos.end ());
105+ sizeRemaining -= batchKeys;
106+ }
107+ } else {
108+ std::generate_n (std::back_inserter (remoteRankInfos), size, [&] {
109+ const auto & keyStr = std::to_string (key++);
110+ store.wait ({keyStr.c_str ()}, getTimeout ());
111+ return store.get (keyStr);
112+ });
113+ }
114+
87115 // Connect every pair
88116 for (int i = 0 ; i < size; i++) {
89117 if (i == rank) {
@@ -95,24 +123,18 @@ void Context::createAndConnectAllPairs(IStore& store) {
95123 continue ;
96124 }
97125
98- // Wait for address of other side of this pair to become available
99- std::ostringstream key;
100- key << i;
101- store.wait ({key.str ()}, getTimeout ());
126+ Rank remoteRankInfo (remoteRankInfos[i]);
102127
103- // Connect to other side of this pair
104- std::vector<char > rankInfoBytes = store.get (key.str ());
105- Rank remoteRankInfo (rankInfoBytes);
106- const auto & remoteHostname = remoteRankInfo.hostname ;
107- if (!localRankSet && remoteHostname == localHostName) {
128+ if (!localRankSet && remoteRankInfo.hostname == localHostName) {
108129 ++localRank;
109130 }
110131
111132 const auto & pair = getPair (i);
112133 auto remoteDeviceAddr = Address (remoteRankInfo.addressBytes ).getSockaddr ();
113134 auto remoteAddr = Address (
114135 remoteDeviceAddr,
115- useRankAsSeqNum ? (ssize_t )rank : remoteRankInfo.pairIdentifiers [rank]);
136+ useRankAsSeqNum ? (sequence_number_t )rank
137+ : remoteRankInfo.pairIdentifiers [rank]);
116138 pair->connect (remoteAddr.bytes ());
117139 }
118140
0 commit comments