Skip to content

Commit 20dc202

Browse files
authored
allow seq_number == global_rank for TCP backend
Differential Revision: D48130088 Pull Request resolved: #407
1 parent b67ecd8 commit 20dc202

File tree

10 files changed

+84
-18
lines changed

10 files changed

+84
-18
lines changed

gloo/common/utils.cc

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,4 +30,10 @@ std::string getHostname() {
3030
return std::string(hostname);
3131
}
3232

33+
bool useRankAsSeqNumber() {
34+
const auto& res = getenv("GLOO_ENABLE_RANK_AS_SEQUENCE_NUMBER");
35+
return res != nullptr &&
36+
(std::string(res) == "True" || std::string(res) == "1");
37+
}
38+
3339
} // namespace gloo

gloo/common/utils.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,4 +14,6 @@ namespace gloo {
1414

1515
std::string getHostname();
1616

17+
bool useRankAsSeqNumber();
18+
1719
} // namespace gloo

gloo/transport/tcp/context.cc

Lines changed: 33 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -49,15 +49,20 @@ void Context::createAndConnectAllPairs(IStore& store) {
4949
int localRank = 0;
5050
bool localRankSet = false;
5151
auto localHostName = getHostname();
52+
bool useRankAsSeqNum = useRankAsSeqNumber();
5253

5354
// We will create all the pairs including self
5455
// the self pair will not be connected
5556
// it's just to keep the later seq num matching logic simple
5657
std::vector<ssize_t> pairIdentifiers;
5758
for (int i = 0; i < size; i++) {
58-
auto& pair = createPair(i);
59-
pairIdentifiers.emplace_back(
60-
static_cast<Pair*>(pair.get())->address().getSeq());
59+
const auto& pair = createPair(i, useRankAsSeqNum);
60+
if (!useRankAsSeqNum) {
61+
// Need to preserve the order of the pair identifiers if we are not using
62+
// the rank as seq number
63+
pairIdentifiers.emplace_back(
64+
static_cast<Pair*>(pair.get())->address().getSeq());
65+
}
6166
}
6267

6368
// Obtain the pair object for this rank
@@ -105,8 +110,9 @@ void Context::createAndConnectAllPairs(IStore& store) {
105110

106111
const auto& pair = getPair(i);
107112
auto remoteDeviceAddr = Address(remoteRankInfo.addressBytes).getSockaddr();
108-
auto remoteAddr =
109-
Address(remoteDeviceAddr, remoteRankInfo.pairIdentifiers[rank]);
113+
auto remoteAddr = Address(
114+
remoteDeviceAddr,
115+
useRankAsSeqNum ? (ssize_t)rank : remoteRankInfo.pairIdentifiers[rank]);
110116
pair->connect(remoteAddr.bytes());
111117
}
112118

@@ -124,7 +130,15 @@ void Context::createAndConnectAllPairs(IStore& store) {
124130

125131
std::unique_ptr<transport::Pair>& Context::createPair(int rank) {
126132
pairs_[rank] = std::unique_ptr<transport::Pair>(
127-
new tcp::Pair(this, device_.get(), rank, getTimeout()));
133+
new tcp::Pair(this, device_.get(), rank, getTimeout(), false));
134+
return pairs_[rank];
135+
}
136+
137+
std::unique_ptr<transport::Pair>& Context::createPair(
138+
int rank,
139+
bool useRankAsSeqNumber = false) {
140+
pairs_[rank] = std::unique_ptr<transport::Pair>(new tcp::Pair(
141+
this, device_.get(), rank, getTimeout(), useRankAsSeqNumber));
128142
return pairs_[rank];
129143
}
130144

@@ -305,14 +319,16 @@ Rank::Rank(const std::vector<char>& bytes) {
305319
bytesOffset += sizeof(addrSz) + addrSz;
306320
// pair identifiers
307321
size_t pairIdChunkSz = bytes.size() - bytesOffset;
308-
GLOO_ENFORCE_EQ(
309-
pairIdChunkSz % sizeof(ssize_t),
310-
0,
311-
"Remaining bytes do not map to entire chunk of pair identifiers");
312-
size_t numPairs = pairIdChunkSz / sizeof(ssize_t);
313-
pairIdentifiers.resize(numPairs);
314-
std::memcpy(
315-
pairIdentifiers.data(), bytes.data() + bytesOffset, pairIdChunkSz);
322+
if (pairIdChunkSz) {
323+
GLOO_ENFORCE_EQ(
324+
pairIdChunkSz % sizeof(ssize_t),
325+
0,
326+
"Remaining bytes do not map to entire chunk of pair identifiers");
327+
size_t numPairs = pairIdChunkSz / sizeof(ssize_t);
328+
pairIdentifiers.resize(numPairs);
329+
std::memcpy(
330+
pairIdentifiers.data(), bytes.data() + bytesOffset, pairIdChunkSz);
331+
}
316332
}
317333

318334
std::vector<char> Rank::bytes() const {
@@ -336,7 +352,9 @@ std::vector<char> Rank::bytes() const {
336352
std::memcpy(bufOffset, addressBytes.data(), addressBytes.size());
337353
bufOffset += addrSz;
338354
// pair identifiers
339-
std::memcpy(bufOffset, pairIdentifiers.data(), pairIdChunkSz);
355+
if (pairIdChunkSz) {
356+
std::memcpy(bufOffset, pairIdentifiers.data(), pairIdChunkSz);
357+
}
340358
return buf;
341359
}
342360

gloo/transport/tcp/context.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,9 @@ class Context : public ::gloo::transport::Context,
3939
virtual void createAndConnectAllPairs(IStore& store) override;
4040

4141
std::unique_ptr<transport::Pair>& createPair(int rank) override;
42+
std::unique_ptr<transport::Pair>& createPair(
43+
int rank,
44+
bool useRankAsSeqNumber);
4245

4346
std::unique_ptr<transport::UnboundBuffer> createUnboundBuffer(
4447
void* ptr,

gloo/transport/tcp/device.cc

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -253,6 +253,10 @@ Address Device::nextAddress() {
253253
return listener_->nextAddress();
254254
}
255255

256+
Address Device::nextAddress(int seq) {
257+
return listener_->nextAddress(seq);
258+
}
259+
256260
bool Device::isInitiator(const Address& local, const Address& remote) const {
257261
int rv = 0;
258262
// The remote side of a pair will be called with the same

gloo/transport/tcp/device.h

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,15 @@ class Device : public ::gloo::transport::Device,
6868
//
6969
Address nextAddress();
7070

71+
// Return a new `Address` instance using the provided sequence number.
72+
//
73+
// This is called by the constructor of the `Pair` class. It gives
74+
// the pair a uniquely identifying address even though the device
75+
// uses a shared listening socket. Caller must provide a unique sequence
76+
// number
77+
//
78+
Address nextAddress(int);
79+
7180
// Connect a pair to a remote.
7281
//
7382
// This is performed by the device instance because we use a single

gloo/transport/tcp/listener.cc

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313

1414
#include <gloo/common/common.h>
1515
#include <gloo/common/logging.h>
16+
#include <gloo/common/utils.h>
1617
#include <gloo/transport/tcp/helpers.h>
1718

1819
namespace gloo {
@@ -26,6 +27,7 @@ Listener::Listener(std::shared_ptr<Loop> loop, const attr& attr)
2627
listener_->bind(attr.ai_addr);
2728
listener_->listen(kBacklog);
2829
addr_ = listener_->sockName();
30+
useRankAsSeqNumber_ = useRankAsSeqNumber();
2931

3032
// Register with loop for readability events.
3133
loop_->registerDescriptor(listener_->fd(), EPOLLIN, this);
@@ -76,9 +78,19 @@ void Listener::handleEvents(int /* unused */) {
7678

7779
Address Listener::nextAddress() {
7880
std::lock_guard<std::mutex> guard(mutex_);
81+
GLOO_ENFORCE(
82+
!useRankAsSeqNumber_,
83+
"Listener cannot use internal sequence with enabled option to use rank as sequence number");
7984
return Address(addr_.getSockaddr(), seq_++);
8085
}
8186

87+
Address Listener::nextAddress(int seq) {
88+
GLOO_ENFORCE(
89+
useRankAsSeqNumber_,
90+
"Listener must be setup to use rank as sequence number");
91+
return Address(addr_.getSockaddr(), seq);
92+
}
93+
8294
void Listener::waitForConnection(sequence_number_t seq, connect_callback_t fn) {
8395
std::unique_lock<std::mutex> lock(mutex_);
8496

gloo/transport/tcp/listener.h

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,8 @@ class Listener final : public Handler {
4242

4343
Address nextAddress();
4444

45+
Address nextAddress(int);
46+
4547
// Wait for connection with sequence number `seq`. The callback is
4648
// always called from a different thread (the event loop thread),
4749
// even if the connection is already available.
@@ -67,6 +69,12 @@ class Listener final : public Handler {
6769

6870
// Sockets by sequence number (while waiting for a pair to call).
6971
std::unordered_map<sequence_number_t, std::shared_ptr<Socket>> seqToSocket_;
72+
73+
// Option to use rank as sequence number and avoid pair identifiers
74+
// to the store during rendezvous. Experimental, disabled by default.
75+
// Can be enabled by setting the environment variable
76+
// GLOO_ENABLE_RANK_AS_SEQUENCE_NUMBER=1.
77+
bool useRankAsSeqNumber_{false};
7078
};
7179

7280
} // namespace tcp

gloo/transport/tcp/pair.cc

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,8 @@ Pair::Pair(
4646
Context* context,
4747
Device* device,
4848
int rank,
49-
std::chrono::milliseconds timeout)
49+
std::chrono::milliseconds timeout,
50+
bool useRankAsSeqNumber)
5051
: context_(context),
5152
device_(device),
5253
rank_(rank),
@@ -56,7 +57,9 @@ Pair::Pair(
5657
busyPoll_(false),
5758
fd_(FD_INVALID),
5859
sendBufferSize_(0),
59-
self_(device_->nextAddress()),
60+
self_(
61+
useRankAsSeqNumber ? device_->nextAddress(rank)
62+
: device_->nextAddress()),
6063
ex_(nullptr) {}
6164

6265
// Destructor performs a "soft" close.

gloo/transport/tcp/pair.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -95,7 +95,8 @@ class Pair : public ::gloo::transport::Pair, public Handler {
9595
Context* context,
9696
Device* device,
9797
int rank,
98-
std::chrono::milliseconds timeout);
98+
std::chrono::milliseconds timeout,
99+
bool useRankAsSeqNumber = false);
99100

100101
virtual ~Pair();
101102

0 commit comments

Comments
 (0)