Skip to content

Commit 5f6d0e8

Browse files
function47 authored and meta-codesync[bot] committed
Use CtranComm hints to override NCCL_CTRAN_BACKENDS
Summary: This diff adds a per-communicator tcpdm backend override, supplied through hints, on top of the current NCCL_CTRAN_BACKENDS selection. When `mccl_tcpdm_backend_override == all`, only the tcp backend is created instead of the backends selected by CVARs. A separate test suite was added for the tcpdm-related tests so that it can be run remotely on a production server through the `-c comms.hosts` option. Reviewed By: saifhhasan Differential Revision: D85455341 fbshipit-source-id: 7a217ce4e41d44b726c20fd627a442c1a18bbb83
1 parent b5f167f commit 5f6d0e8

File tree

4 files changed

+193
-15
lines changed

4 files changed

+193
-15
lines changed

comms/ctran/mapper/CtranMapper.cc

Lines changed: 46 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,48 @@
2525
#endif
2626

2727
using namespace ncclx;
28+
namespace {
29+
std::vector<CtranMapperBackend> getToEnableBackends(
30+
const std::vector<CommBackend>& overrideBackend) {
31+
const std::unordered_map<enum CommBackend, CtranMapperBackend>
32+
CommBackendMap = {
33+
{CommBackend::UNSET, CtranMapperBackend::UNSET},
34+
{CommBackend::IB, CtranMapperBackend::IB},
35+
{CommBackend::NVL, CtranMapperBackend::NVL},
36+
{CommBackend::SOCKET, CtranMapperBackend::SOCKET},
37+
{CommBackend::TCPDM, CtranMapperBackend::TCPDM}};
38+
39+
const std::unordered_map<enum NCCL_CTRAN_BACKENDS, CtranMapperBackend>
40+
NCCLCtranBackendMap = {
41+
{NCCL_CTRAN_BACKENDS::ib, CtranMapperBackend::IB},
42+
{NCCL_CTRAN_BACKENDS::nvl, CtranMapperBackend::NVL},
43+
{NCCL_CTRAN_BACKENDS::socket, CtranMapperBackend::SOCKET},
44+
{NCCL_CTRAN_BACKENDS::tcpdm, CtranMapperBackend::TCPDM}};
45+
46+
std::vector<CtranMapperBackend> enableBackends;
47+
48+
if (overrideBackend.size() == 0 ||
49+
(overrideBackend.size() == 1 &&
50+
overrideBackend[0] == CommBackend::UNSET)) {
51+
for (auto& b : NCCL_CTRAN_BACKENDS) {
52+
enableBackends.emplace_back(NCCLCtranBackendMap.at(b));
53+
}
54+
} else {
55+
CLOGF(
56+
WARN,
57+
"CTRAN-MAPPER: Try to override backends through Ctran Config. Currently it is specific config for MCCL. If you are using NCCL with NCCL_CTRAN_BACKENDS, please report this to MCCL team");
58+
for (auto& b : overrideBackend) {
59+
if (b == CommBackend::UNSET) {
60+
FB_ERRORTHROW(
61+
commInvalidUsage, "CTRAN-MAPPER: Invalid override backend UNSET");
62+
}
63+
enableBackends.emplace_back(CommBackendMap.at(b));
64+
}
65+
}
66+
67+
return enableBackends;
68+
}
69+
} // namespace
2870

2971
CtranMapper::CtranMapper(CtranComm* comm) {
3072
const auto statex = comm->statex_.get();
@@ -39,26 +81,15 @@ CtranMapper::CtranMapper(CtranComm* comm) {
3981
statex->nRanks()};
4082

4183
this->comm = comm;
42-
84+
auto backendsToEnable = getToEnableBackends(comm->config_.backends);
4385
/* check user preference for backends */
4486
std::vector<bool> enableBackends(CtranMapperBackend::NUM_BACKENDS, false);
4587
iPutCount = std::vector<int>(CtranMapperBackend::NUM_BACKENDS, 0);
4688
iGetCount = std::vector<int>(CtranMapperBackend::NUM_BACKENDS, 0);
4789
std::vector<std::string> enableBackendsStrs;
48-
for (auto b : NCCL_CTRAN_BACKENDS) {
49-
if (b == NCCL_CTRAN_BACKENDS::ib) {
50-
enableBackends[CtranMapperBackend::IB] = true;
51-
enableBackendsStrs.push_back(backendToStr(CtranMapperBackend::IB));
52-
} else if (b == NCCL_CTRAN_BACKENDS::nvl) {
53-
enableBackends[CtranMapperBackend::NVL] = true;
54-
enableBackendsStrs.push_back(backendToStr(CtranMapperBackend::NVL));
55-
} else if (b == NCCL_CTRAN_BACKENDS::socket) {
56-
enableBackends[CtranMapperBackend::SOCKET] = true;
57-
enableBackendsStrs.push_back(backendToStr(CtranMapperBackend::SOCKET));
58-
} else if (b == NCCL_CTRAN_BACKENDS::tcpdm) {
59-
enableBackends[CtranMapperBackend::TCPDM] = true;
60-
enableBackendsStrs.push_back(backendToStr(CtranMapperBackend::TCPDM));
61-
}
90+
for (auto b : backendsToEnable) {
91+
enableBackends.at(b) = true;
92+
enableBackendsStrs.push_back(backendToStr(b));
6293
}
6394

6495
CLOGF_SUBSYS(
Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,81 @@
1+
// Copyright (c) Meta Platforms, Inc. and affiliates.
2+
3+
#include <gmock/gmock.h>
4+
#include <gtest/gtest.h>
5+
6+
#include <cstdlib>
7+
#include <memory>
8+
9+
#include "comms/ctran/mapper/CtranMapper.h"
10+
#include "comms/ctran/mapper/CtranMapperImpl.h"
11+
#include "comms/ctran/mapper/CtranMapperRegMem.h"
12+
#include "comms/ctran/tests/CtranXPlatUtUtils.h"
13+
#include "comms/testinfra/TestUtils.h"
14+
#include "comms/testinfra/TestsDistUtils.h"
15+
#include "comms/utils/logger/LogUtils.h"
16+
17+
using ctran::CtranTcpDm;
18+
19+
class CtranMapperTcpdmTest : public ::testing::Test {
20+
public:
21+
std::unique_ptr<TestCtranCommRAII> commRAII_;
22+
CtranComm* dummyComm_{nullptr};
23+
std::unique_ptr<CtranMapper> mapper;
24+
25+
protected:
26+
void SetUp() override {
27+
setenv("TCP_DEVMEM_SKIP_AGENT", "1", 1);
28+
setenv("TCP_DEVMEM_RECONFIGURE_DEVICES", "0", 1);
29+
// TCPDM only available with certain kernel version. Skip test if with
30+
// incompatible kernel
31+
try {
32+
setenv("NCCL_CTRAN_BACKENDS", "tcpdm", 1);
33+
ncclCvarInit();
34+
auto commRAII = createDummyCtranComm();
35+
commRAII.reset();
36+
} catch (const std::runtime_error& e) {
37+
GTEST_SKIP() << "TCPDM backend not enabled. Skip test";
38+
}
39+
}
40+
void TearDown() override {
41+
unsetenv("NCCL_CTRAN_BACKENDS");
42+
unsetenv("MCCL_CTRAN_BACKENDS");
43+
commRAII_.reset();
44+
}
45+
void createComm() {
46+
ncclCvarInit();
47+
commRAII_ = createDummyCtranComm();
48+
dummyComm_ = commRAII_->ctranComm;
49+
}
50+
};
51+
52+
// With NCCL_CTRAN_BACKENDS set to "tcpdm" alone, the mapper is expected to
// report TCPDM, and none of IB/NVL/SOCKET, for the local rank.
TEST_F(CtranMapperTcpdmTest, EnableTCPDMBackendThroughCVARs) {
  setenv("NCCL_CTRAN_BACKENDS", "tcpdm", 1);
  ASSERT_STREQ(getenv("NCCL_CTRAN_BACKENDS"), "tcpdm");

  createComm();
  const auto tcpdmMapper = std::make_unique<CtranMapper>(dummyComm_);
  const auto myRank = dummyComm_->statex_->rank();

  EXPECT_FALSE(tcpdmMapper->hasBackend(myRank, CtranMapperBackend::IB));
  EXPECT_FALSE(tcpdmMapper->hasBackend(myRank, CtranMapperBackend::NVL));
  EXPECT_FALSE(tcpdmMapper->hasBackend(myRank, CtranMapperBackend::SOCKET));
  EXPECT_TRUE(tcpdmMapper->hasBackend(myRank, CtranMapperBackend::TCPDM));
}
63+
// NCCL_CTRAN_BACKENDS asks for nvl/ib/socket while MCCL_CTRAN_BACKENDS asks
// for tcpdm; the mapper is expected to end up with TCPDM only.
TEST_F(CtranMapperTcpdmTest, OverrideBackendThroughHints) {
  setenv("NCCL_CTRAN_BACKENDS", "nvl,ib,socket", 1);
  ASSERT_STREQ(getenv("NCCL_CTRAN_BACKENDS"), "nvl,ib,socket");
  setenv("MCCL_CTRAN_BACKENDS", "tcpdm", 1);
  ASSERT_STREQ(getenv("MCCL_CTRAN_BACKENDS"), "tcpdm");

  createComm();
  const auto overriddenMapper = std::make_unique<CtranMapper>(dummyComm_);
  const auto myRank = dummyComm_->statex_->rank();

  EXPECT_FALSE(overriddenMapper->hasBackend(myRank, CtranMapperBackend::IB));
  EXPECT_FALSE(overriddenMapper->hasBackend(myRank, CtranMapperBackend::NVL));
  EXPECT_FALSE(
      overriddenMapper->hasBackend(myRank, CtranMapperBackend::SOCKET));
  EXPECT_TRUE(overriddenMapper->hasBackend(myRank, CtranMapperBackend::TCPDM));
}
76+
77+
int main(int argc, char* argv[]) {
78+
::testing::InitGoogleTest(&argc, argv);
79+
::testing::AddGlobalTestEnvironment(new DistEnvironmentBase);
80+
return RUN_ALL_TESTS();
81+
}

comms/ctran/mapper/tests/CtranMapperUT.cc

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,63 @@ class CtranMapperTest : public ::testing::Test {
5757
logGpuMemoryStats(cudaDev);
5858
}
5959
};
60+
// Backends requested purely through NCCL_CTRAN_BACKENDS: ib, nvl and socket.
TEST(CtranMapperUT, EnableBackendThroughCVARs) {
  setenv("NCCL_CTRAN_BACKENDS", "ib, nvl, socket", 1);
  ncclCvarInit();

  const auto commHolder = createDummyCtranComm();
  auto* const comm = commHolder->ctranComm;
  const auto cvarMapper = std::make_unique<CtranMapper>(comm);
  const auto myRank = comm->statex_->rank();

  EXPECT_TRUE(cvarMapper->hasBackend(myRank, CtranMapperBackend::IB));
  EXPECT_TRUE(cvarMapper->hasBackend(myRank, CtranMapperBackend::NVL));
  // Socket is disabled if ib is enabled
  EXPECT_FALSE(cvarMapper->hasBackend(myRank, CtranMapperBackend::SOCKET));
}
72+
73+
// Without ib in NCCL_CTRAN_BACKENDS, both nvl and socket are expected to be
// reported for the local rank.
TEST(CtranMapperUT, EnableBackendThroughCVARsWithoutIB) {
  setenv("NCCL_CTRAN_BACKENDS", "nvl, socket", 1);
  ncclCvarInit();

  const auto commHolder = createDummyCtranComm();
  auto* const comm = commHolder->ctranComm;
  const auto cvarMapper = std::make_unique<CtranMapper>(comm);
  const auto myRank = comm->statex_->rank();

  EXPECT_FALSE(cvarMapper->hasBackend(myRank, CtranMapperBackend::IB));
  EXPECT_TRUE(cvarMapper->hasBackend(myRank, CtranMapperBackend::NVL));
  EXPECT_TRUE(cvarMapper->hasBackend(myRank, CtranMapperBackend::SOCKET));
}
84+
// MCCL_CTRAN_BACKENDS="not_set" is expected to leave the NCCL_CTRAN_BACKENDS
// selection (nvl, socket) in effect.
TEST(CtranMapperUT, EnableBackendWithMCCLBackendUNSET) {
  // Removes the MCCL override on every exit path, including an exception from
  // comm/mapper construction, so it cannot leak into later tests.
  struct McclBackendsEnvGuard {
    ~McclBackendsEnvGuard() {
      unsetenv("MCCL_CTRAN_BACKENDS");
    }
  } mcclEnvGuard;

  setenv("NCCL_CTRAN_BACKENDS", "nvl, socket", 1);
  setenv("MCCL_CTRAN_BACKENDS", "not_set", 1);
  ncclCvarInit();

  auto commRAII = createDummyCtranComm();
  auto dummyComm = commRAII->ctranComm;
  auto mapper = std::make_unique<CtranMapper>(dummyComm);
  auto rank = dummyComm->statex_->rank();

  EXPECT_FALSE(mapper->hasBackend(rank, CtranMapperBackend::IB));
  EXPECT_TRUE(mapper->hasBackend(rank, CtranMapperBackend::NVL));
  EXPECT_TRUE(mapper->hasBackend(rank, CtranMapperBackend::SOCKET));
}
97+
98+
// MCCL_CTRAN_BACKENDS="ib" is expected to replace the NCCL_CTRAN_BACKENDS
// selection (nvl, socket): only IB is reported for the local rank.
TEST(CtranMapperUT, EnableBackendWithMCCLBackendOverride) {
  // Removes the MCCL override on every exit path, including an exception from
  // comm/mapper construction, so it cannot leak into later tests.
  struct McclBackendsEnvGuard {
    ~McclBackendsEnvGuard() {
      unsetenv("MCCL_CTRAN_BACKENDS");
    }
  } mcclEnvGuard;

  setenv("NCCL_CTRAN_BACKENDS", "nvl, socket", 1);
  setenv("MCCL_CTRAN_BACKENDS", "ib", 1);
  ncclCvarInit();

  auto commRAII = createDummyCtranComm();
  auto dummyComm = commRAII->ctranComm;
  auto mapper = std::make_unique<CtranMapper>(dummyComm);
  auto rank = dummyComm->statex_->rank();

  EXPECT_TRUE(mapper->hasBackend(rank, CtranMapperBackend::IB));
  EXPECT_FALSE(mapper->hasBackend(rank, CtranMapperBackend::NVL));
  EXPECT_FALSE(mapper->hasBackend(rank, CtranMapperBackend::SOCKET));
}
111+
112+
// Requesting tcpdm together with nvl/ib/socket is expected to make comm
// creation fail with std::runtime_error.
TEST(CtranMapperUT, EnableBackendThroughCVARsWithTCPandIB) {
  setenv("NCCL_CTRAN_BACKENDS", "nvl, ib, socket, tcpdm", 1);
  ncclCvarInit();
  EXPECT_THROW(createDummyCtranComm(), std::runtime_error);

  // Drop the rejected combination and reload CVARs so it cannot affect tests
  // that run later in the same process.
  unsetenv("NCCL_CTRAN_BACKENDS");
  ncclCvarInit();
}
60117

61118
TEST(CtranMapperUT, BackendEnum) {
62119
for (int i = 0; i < CtranMapperBackend::NUM_BACKENDS; ++i) {

comms/ctran/tests/CtranXPlatUtUtils.h

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,15 @@ class TestCtranCommRAII {
4444
TestCtranCommRAII(std::unique_ptr<mccl::McclComm> mcclComm);
4545
CtranComm* ctranComm{nullptr};
4646

47+
// Releases the owned McclComm first, then destroys the CtranComm if one is
// attached.
// NOTE(review): if ctranComm can point into the object owned by mcclComm_,
// calling destroy() after the reset would touch freed memory; confirm
// against the constructors.
~TestCtranCommRAII() {
  // reset() on an empty unique_ptr is a no-op, so no null guard is needed.
  mcclComm_.reset();
  if (ctranComm != nullptr) {
    ctranComm->destroy();
  }
}
55+
4756
private:
4857
std::unique_ptr<mccl::McclComm> mcclComm_;
4958
};

0 commit comments

Comments
 (0)