Skip to content

Commit 4ff6edf

Browse files
authored
Use multi_get for store that has extended API support.
Differential Revision: D52083376
Pull Request resolved: #408
1 parent 20dc202 commit 4ff6edf

File tree

5 files changed

+58
-12
lines changed

5 files changed

+58
-12
lines changed

gloo/common/store.h

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,21 @@ class IStore {
2525
virtual void wait(
2626
const std::vector<std::string>& keys,
2727
const std::chrono::milliseconds& timeout) = 0;
28+
29+
// Extended 2.0 API support
30+
virtual bool has_v2_support() = 0;
31+
32+
virtual std::vector<std::vector<char>> multi_get(
33+
const std::vector<std::string>& keys) = 0;
34+
35+
virtual void multi_set(
36+
const std::vector<std::string>& keys,
37+
const std::vector<std::vector<char>>& values) = 0;
38+
39+
virtual void append(
40+
const std::string& key,
41+
const std::vector<char>& value) = 0;
42+
virtual int64_t add(const std::string& key, int64_t value) = 0;
2843
};
2944

3045
} // namespace gloo

gloo/common/utils.cc

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,4 +36,10 @@ bool useRankAsSeqNumber() {
3636
(std::string(res) == "True" || std::string(res) == "1");
3737
}
3838

39+
bool isStoreExtendedApiEnabled() {
40+
const auto& res = std::getenv("GLOO_ENABLE_STORE_V2_API");
41+
return res != nullptr &&
42+
(std::string(res) == "True" || std::string(res) == "1");
43+
}
44+
3945
} // namespace gloo

gloo/common/utils.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,4 +16,6 @@ std::string getHostname();
1616

1717
bool useRankAsSeqNumber();
1818

19+
bool isStoreExtendedApiEnabled();
20+
1921
} // namespace gloo

gloo/transport/tcp/context.cc

Lines changed: 34 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -8,10 +8,12 @@
88

99
#include "gloo/transport/tcp/context.h"
1010

11+
#include <algorithm>
12+
#include <cstdint>
1113
#include <cstring>
1214
#include <iostream>
15+
#include <string>
1316

14-
#include "gloo/common/error.h"
1517
#include "gloo/common/logging.h"
1618
#include "gloo/common/utils.h"
1719
#include "gloo/transport/tcp/device.h"
@@ -22,6 +24,8 @@ namespace gloo {
2224
namespace transport {
2325
namespace tcp {
2426

27+
constexpr int kDefaultBatchSize = 128;
28+
2529
Context::Context(std::shared_ptr<Device> device, int rank, int size)
2630
: ::gloo::transport::Context(rank, size), device_(std::move(device)) {}
2731

@@ -78,12 +82,36 @@ void Context::createAndConnectAllPairs(IStore& store) {
7882
// which does not have the rank info hosted at a higher `Pair` level).
7983
// So better safe than sorry for now we try to minimize the changeset needed.
8084
const auto& currentRankPair = getPair(rank);
81-
auto deviceAddress = Address(
85+
const auto& deviceAddress = Address(
8286
static_cast<const Pair*>(currentRankPair.get())->address().getSockaddr());
8387
Rank currentRankInfo(
8488
localHostName, deviceAddress.bytes(), std::move(pairIdentifiers));
8589
store.set(std::to_string(rank), currentRankInfo.bytes());
8690

91+
std::vector<std::vector<char>> remoteRankInfos;
92+
int key = 0;
93+
if (isStoreExtendedApiEnabled() && store.has_v2_support()) {
94+
auto sizeRemaining = size;
95+
while (sizeRemaining > 0) {
96+
const auto batchKeys = std::min(kDefaultBatchSize, sizeRemaining);
97+
std::vector<std::string> keys(batchKeys);
98+
std::generate_n(
99+
keys.begin(), batchKeys, [&] { return std::to_string(key++); });
100+
const auto& batchRemoteInfos = store.multi_get(keys);
101+
remoteRankInfos.insert(
102+
remoteRankInfos.end(),
103+
batchRemoteInfos.begin(),
104+
batchRemoteInfos.end());
105+
sizeRemaining -= batchKeys;
106+
}
107+
} else {
108+
std::generate_n(std::back_inserter(remoteRankInfos), size, [&] {
109+
const auto& keyStr = std::to_string(key++);
110+
store.wait({keyStr.c_str()}, getTimeout());
111+
return store.get(keyStr);
112+
});
113+
}
114+
87115
// Connect every pair
88116
for (int i = 0; i < size; i++) {
89117
if (i == rank) {
@@ -95,24 +123,18 @@ void Context::createAndConnectAllPairs(IStore& store) {
95123
continue;
96124
}
97125

98-
// Wait for address of other side of this pair to become available
99-
std::ostringstream key;
100-
key << i;
101-
store.wait({key.str()}, getTimeout());
126+
Rank remoteRankInfo(remoteRankInfos[i]);
102127

103-
// Connect to other side of this pair
104-
std::vector<char> rankInfoBytes = store.get(key.str());
105-
Rank remoteRankInfo(rankInfoBytes);
106-
const auto& remoteHostname = remoteRankInfo.hostname;
107-
if (!localRankSet && remoteHostname == localHostName) {
128+
if (!localRankSet && remoteRankInfo.hostname == localHostName) {
108129
++localRank;
109130
}
110131

111132
const auto& pair = getPair(i);
112133
auto remoteDeviceAddr = Address(remoteRankInfo.addressBytes).getSockaddr();
113134
auto remoteAddr = Address(
114135
remoteDeviceAddr,
115-
useRankAsSeqNum ? (ssize_t)rank : remoteRankInfo.pairIdentifiers[rank]);
136+
useRankAsSeqNum ? (sequence_number_t)rank
137+
: remoteRankInfo.pairIdentifiers[rank]);
116138
pair->connect(remoteAddr.bytes());
117139
}
118140

gloo/transport/tcp/listener.cc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
#include <gloo/common/logging.h>
1616
#include <gloo/common/utils.h>
1717
#include <gloo/transport/tcp/helpers.h>
18+
#
1819

1920
namespace gloo {
2021
namespace transport {

0 commit comments

Comments
 (0)