Skip to content

Commit 72ae2a6

Browse files
function47meta-codesync[bot]
authored andcommitted
Add backend override config to ctranComm init
Summary: As title. 1. Add backend field to the ctranConfig struct. 2. Allow ctranConfig to be passed when ctranComm is created. Reviewed By: saifhhasan Differential Revision: D85444909 fbshipit-source-id: 89c98ad7a4b7b39e4b81fa12ae794ade4b32f195
1 parent 2bc40c4 commit 72ae2a6

File tree

3 files changed

+37
-10
lines changed

3 files changed

+37
-10
lines changed

comms/ctran/CtranComm.h

Lines changed: 15 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -16,16 +16,19 @@
1616
#include "comms/utils/colltrace/CollTraceInterface.h"
1717
#include "comms/utils/commSpecs.h"
1818

19+
using meta::comms::CommBackend;
1920
struct ctranConfig {
2021
int blocking{-1};
2122
const char* commDesc{nullptr};
2223
const char* ncclAllGatherAlgo{nullptr};
24+
std::vector<enum CommBackend> backends = {};
2325

2426
bool operator==(const ctranConfig& other) const {
2527
return (
2628
blocking == other.blocking && commDesc == other.commDesc &&
2729
commDesc == other.commDesc &&
28-
ncclAllGatherAlgo == other.ncclAllGatherAlgo);
30+
ncclAllGatherAlgo == other.ncclAllGatherAlgo &&
31+
backends == other.backends);
2932
}
3033
};
3134

@@ -46,8 +49,9 @@ class CtranComm {
4649
// For real communicationator we should use factory method to create.
4750
explicit CtranComm(
4851
std::shared_ptr<Abort> abort =
49-
ctran::utils::createAbort(/*enabled=*/false))
50-
: abort_(abort) {
52+
ctran::utils::createAbort(/*enabled=*/false),
53+
ctranConfig commConfig = ctranConfig{})
54+
: config_(commConfig), abort_(abort) {
5155
asyncErr_ =
5256
std::make_shared<AsyncError>(NCCL_CTRAN_ABORT_ON_ERROR, "CtranComm");
5357
if (!abort_) {
@@ -58,11 +62,12 @@ class CtranComm {
5862
opCount_ = &ctranOpCount_;
5963
}
6064

61-
// The MemCache allocator is destroyed in a different time than all other
62-
// Ctran resources. To accommodate this, we split the CtranComm destructor
63-
// into two parts. In the first part, we destroy all resources except for
64-
// MemCache. The second part is moved to the destructor, where it is safe to
65-
// destroy MemCache and reset its reference.
65+
// The MemCache allocator is destroyed in a different time than all
66+
// other Ctran resources. To accommodate this, we split the CtranComm
67+
// destructor into two parts. In the first part, we destroy all
68+
// resources except for MemCache. The second part is moved to the
69+
// destructor, where it is safe to destroy MemCache and reset its
70+
// reference.
6671
void destroy() {
6772
// All smart pointers are automatically de-initialized, but we want to
6873
// ensure they do so in a specific order. Therefore, we manually handle
@@ -72,8 +77,8 @@ class CtranComm {
7277
collTrace_.reset();
7378
colltraceNew_.reset();
7479
statex_.reset();
75-
// NOTE: memCache needs to be destroyed after transportProxy_ to release all
76-
// buffers
80+
// NOTE: memCache needs to be destroyed after transportProxy_ to release
81+
// all buffers
7782
memCache_.reset();
7883

7984
this->logMetaData_.commDesc.clear();

comms/ctran/tests/CtranCommTest.cc

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,4 +49,17 @@ TEST(CtranCommTest, AbortAvailableAndDisabled) {
4949
EXPECT_FALSE(comm.testAbort());
5050
}
5151

52+
TEST(CtranCommTest, ctranCommConfigTest) {
53+
auto abort = ctran::utils::createAbort(/*enabled=*/true);
54+
ctranConfig config = {
55+
.backends = {CommBackend::IB, CommBackend::NVL, CommBackend::SOCKET}};
56+
57+
CtranComm comm(abort, config);
58+
EXPECT_EQ(comm.config_.backends.size(), 3);
59+
60+
/// Explictly create comm with false abort as first argument is unommitable
61+
CtranComm comm2(ctran::utils::createAbort(false));
62+
EXPECT_EQ(comm2.config_.backends.size(), 0);
63+
}
64+
5265
} // namespace ctran::testing

comms/utils/commSpecs.h

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -299,6 +299,15 @@ enum class CommProtocol {
299299
NumProtocols = 3 // Simple/LL/LL128
300300
};
301301

302+
enum class CommBackend {
303+
UNSET = 0,
304+
IB = 1,
305+
NVL = 2,
306+
SOCKET = 3,
307+
TCPDM = 4,
308+
NUM_BACKENDS = 5
309+
};
310+
302311
class CommsError {
303312
public:
304313
CommsError(std::string msg, commResult_t code)

0 commit comments

Comments
 (0)