Skip to content

Commit ddf456d

Browse files
arttianezhu and meta-codesync[bot]
authored and committed
AllReduce ctdirect enable FT
Summary: As the title says, enable fault tolerance (FT) for the AllReduce `ctdirect` algorithm, so that it can be used as a fallback for `ctring`.
Reviewed By: saifhhasan
Differential Revision: D85370481
fbshipit-source-id: 55921a2d59d8c9ddf714afc7863a97939d2bdb87
1 parent 598a9e5 commit ddf456d

File tree

4 files changed

+22
-14
lines changed

4 files changed

+22
-14
lines changed

comms/ctran/algos/AllReduce/AllReduce.cc

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -43,9 +43,6 @@ commResult_t ctranAllReduce(
4343
sendbuff, recvbuff, count, datatype, redOp, comm, stream, timeout);
4444
case NCCL_ALLREDUCE_ALGO::ctdirect:
4545
default:
46-
if (timeout != std::nullopt) {
47-
CLOGF(WARN, "timeout is ignored for AllReduce ctdirect algorithm");
48-
}
4946
return ctranAllReduceDirect(
5047
sendbuff, recvbuff, count, datatype, redOp, comm, stream, timeout);
5148
}

comms/ctran/algos/AllReduce/AllReduceDirect.cc

Lines changed: 17 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,14 @@ static const auto myAlgo = NCCL_ALLREDUCE_ALGO::ctdirect;
5252
* registers and reduces them into the receive buffer.
5353
*/
5454

55+
#define THROW_IF_ABORTED(code) \
56+
do { \
57+
code; \
58+
if (comm->testAbort()) { \
59+
throw ctran::utils::Exception("comm aborted", commRemoteError); \
60+
} \
61+
} while (0)
62+
5563
static commResult_t impl(
5664
const std::vector<std::unique_ptr<struct OpElem>>& opGroup) {
5765
struct OpElem* op = opGroup.front().get();
@@ -250,7 +258,7 @@ static commResult_t impl(
250258
}
251259

252260
elem->post();
253-
elem->wait();
261+
THROW_IF_ABORTED(elem->wait(comm->getAbort()));
254262

255263
/* Step 2: Inter-node Reduce-scatter */
256264
/* wait for inter-node data transfer to perform local reduction */
@@ -305,7 +313,7 @@ static commResult_t impl(
305313
elem->stridedReduce.stride = chunkCount;
306314
/* poke kernel to start the local reduction */
307315
elem->post();
308-
elem->wait();
316+
THROW_IF_ABORTED(elem->wait(comm->getAbort()));
309317

310318
/* Step 3: Inter-node Allgather */
311319
/* wait for inter-node data transfer to perform local reduction */
@@ -346,6 +354,7 @@ static commResult_t impl(
346354
}
347355
}
348356
}
357+
THROW_IF_ABORTED();
349358

350359
/* Step 4: Intra-node Allgather */
351360
elem = op->allreduce.kElemStepMap.at(
@@ -363,7 +372,7 @@ static commResult_t impl(
363372

364373
/* poke kernel to start the allgather */
365374
elem->post();
366-
elem->wait();
375+
THROW_IF_ABORTED(elem->wait(comm->getAbort()));
367376

368377
op->allreduce.sendbuff =
369378
BUFOFFSET(op->allreduce.sendbuff, stepCount * typeSize);
@@ -407,7 +416,7 @@ static commResult_t impl(
407416
}
408417
}
409418
elem->post();
410-
elem->wait();
419+
THROW_IF_ABORTED(elem->wait(comm->getAbort()));
411420

412421
/* Step 6: Intra-node bcast */
413422
elem = op->allreduce.kElemStepMap.at(
@@ -421,7 +430,7 @@ static commResult_t impl(
421430
}
422431
}
423432
elem->post();
424-
elem->wait();
433+
THROW_IF_ABORTED(elem->wait(comm->getAbort()));
425434

426435
/* Step 7: Inter-node allreduce */
427436
/* wait for inter-node data transfer to perform local reduction */
@@ -472,7 +481,7 @@ static commResult_t impl(
472481
elem->stridedReduce.blockCount = remCount;
473482
elem->stridedReduce.stride = remCount;
474483
elem->post();
475-
elem->wait();
484+
THROW_IF_ABORTED(elem->wait(comm->getAbort()));
476485
}
477486

478487
if (localRegSend == true) {
@@ -693,8 +702,8 @@ commResult_t ctranAllReduceDirect(
693702

694703
opGroup.push_back(std::move(op));
695704

696-
FB_COMMCHECK(
697-
comm->ctran_->gpe->submit(std::move(opGroup), impl, config, func));
705+
FB_COMMCHECK(comm->ctran_->gpe->submit(
706+
std::move(opGroup), impl, config, func, timeout));
698707

699708
if (count * typeSize < CTRAN_MIN_REGISTRATION_SIZE) {
700709
FB_CUDACHECK(cudaMemcpyAsync(

comms/ctran/algos/AllReduce/AllReduceDirect.cuh

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -79,12 +79,12 @@ __global__ void ncclKernelAllReduceCtranDirect(
7979
const auto tId = threadIdx.x;
8080
const auto bId = blockIdx.x;
8181

82+
devStateLoadToShm(&flag[bId], devState);
83+
8284
if (flag && tId == 0) {
8385
ctran::device::KernelStartGpe(&flag[bId]);
8486
}
8587

86-
devStateLoadToShm(devState);
87-
8888
const auto nLocalRanks = statex->nLocalRanks();
8989
const auto localRank = statex->localRank();
9090

comms/ctran/tests/CtranAllReduceTest.cc

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -253,7 +253,9 @@ TEST_P(CtranAllReduceTest, Rank2AbsentTimeout) {
253253
INSTANTIATE_TEST_SUITE_P(
254254
AllCombinations,
255255
CtranAllReduceTest,
256-
::testing::Values(std::make_tuple("ctring", NCCL_ALLREDUCE_ALGO::ctring)),
256+
::testing::Values(
257+
std::make_tuple("ctring", NCCL_ALLREDUCE_ALGO::ctring),
258+
std::make_tuple("ctdirect", NCCL_ALLREDUCE_ALGO::ctdirect)),
257259
[](const ::testing::TestParamInfo<AllReduceTestParam>& info) {
258260
return std::get<0>(info.param);
259261
});

0 commit comments

Comments
 (0)