Skip to content

Commit 5c523a7

Browse files
arttianezhumeta-codesync[bot]
authored andcommitted
AllReduce ctring fallback to ctdirect if nRanks == 1
Summary: As title. This is needed to close Fault Tolerant AllReduce support gap. Note, for Fault Tolerance, `nRanks == 1` allreduce case do not need to be abort-able. The only requirement is for the errors to be caught and handled. This is a Ctran mechanism today, which is not algo specific. Reviewed By: dboyda Differential Revision: D85346276 fbshipit-source-id: 13f04445954f487d2cd7e1af0e801a62da801b83
1 parent 3b65b52 commit 5c523a7

File tree

2 files changed

+92
-0
lines changed

2 files changed

+92
-0
lines changed

comms/ctran/algos/AllReduce/AllReduce.cc

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,11 @@ commResult_t ctranAllReduce(
3434
return ctranAllReduceARG(
3535
sendbuff, recvbuff, count, datatype, redOp, comm, stream, timeout);
3636
case NCCL_ALLREDUCE_ALGO::ctring:
37+
if (comm->statex_->nRanks() == 1) {
38+
// TODO(T242570177): this is a temp workaround for nRanks == 1.
39+
return ctranAllReduceDirect(
40+
sendbuff, recvbuff, count, datatype, redOp, comm, stream, timeout);
41+
}
3742
return ctranAllReduceRing(
3843
sendbuff, recvbuff, count, datatype, redOp, comm, stream, timeout);
3944
case NCCL_ALLREDUCE_ALGO::ctdirect:

comms/ctran/tests/CtranAllReduceTest.cc

Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -410,4 +410,91 @@ INSTANTIATE_TEST_SUITE_P(
410410
std::make_tuple<>(8, commInt32),
411411
std::make_tuple<>(8, commInt8)));
412412

413+
class CtranAllReduceRingOneRankTest : public CtranStandaloneMultiRankBaseTest {
414+
protected:
415+
static constexpr int kNRanks = 1;
416+
static constexpr commRedOp_t kReduceOpType = commSum;
417+
static constexpr commDataType_t kDataType = commInt32;
418+
static constexpr size_t kTypeSize = sizeof(int);
419+
static constexpr size_t kBufferNElem = kBufferSize / kTypeSize;
420+
421+
void SetUp() override {
422+
setenv("NCCL_ALLREDUCE_ALGO", "ctring", 1);
423+
424+
CtranStandaloneMultiRankBaseTest::SetUp();
425+
}
426+
};
427+
428+
TEST_F(CtranAllReduceRingOneRankTest, Basic) {
429+
ASSERT_EQ(NCCL_ALLREDUCE_ALGO, NCCL_ALLREDUCE_ALGO::ctring);
430+
431+
CtranStandaloneMultiRankBaseTest::startWorkers(
432+
kNRanks, /*aborts=*/{ctran::utils::createAbort(/*enabled=*/true)});
433+
434+
run(/*rank=*/0, [this](PerRankState& state) {
435+
// set up src buffer to hold magic values, and zero out dst buffers
436+
int magic = 0xdeadbeef;
437+
int srcHost[kBufferNElem];
438+
int dstHost[kBufferNElem];
439+
for (int i = 0; i < kBufferNElem; ++i) {
440+
srcHost[i] = magic + i;
441+
}
442+
memset(dstHost, 0, kBufferSize);
443+
ASSERT_EQ(
444+
cudaSuccess,
445+
cudaMemcpy(
446+
state.srcBuffer, srcHost, kBufferSize, cudaMemcpyHostToDevice));
447+
ASSERT_EQ(cudaSuccess, cudaMemset(state.dstBuffer, 0, kBufferSize));
448+
449+
// warmup
450+
void* srcHandle;
451+
void* dstHandle;
452+
ASSERT_EQ(
453+
commSuccess,
454+
state.ctranComm->ctran_->commRegister(
455+
state.srcBuffer, kBufferSize, &srcHandle));
456+
ASSERT_EQ(
457+
commSuccess,
458+
state.ctranComm->ctran_->commRegister(
459+
state.dstBuffer, kBufferSize, &dstHandle));
460+
SCOPE_EXIT {
461+
// deregistering will happen after streamSync below
462+
state.ctranComm->ctran_->commDeregister(dstHandle);
463+
state.ctranComm->ctran_->commDeregister(srcHandle);
464+
};
465+
466+
CLOGF(INFO, "rank {} allReduce completed registration", state.rank);
467+
468+
EXPECT_EQ(
469+
commSuccess,
470+
ctranAllReduce(
471+
state.srcBuffer,
472+
state.dstBuffer,
473+
kBufferNElem,
474+
kDataType,
475+
kReduceOpType,
476+
state.ctranComm.get(),
477+
state.stream,
478+
std::nullopt,
479+
/*timeout=*/std::nullopt));
480+
481+
CLOGF(INFO, "rank {} allReduce scheduled", state.rank);
482+
483+
// ensure async execution completion and no error
484+
EXPECT_EQ(cudaSuccess, cudaStreamSynchronize(state.stream));
485+
EXPECT_EQ(commSuccess, state.ctranComm->getAsyncResult());
486+
487+
CLOGF(INFO, "rank {} allReduce task completed", state.rank);
488+
489+
// validate results
490+
ASSERT_EQ(
491+
cudaSuccess,
492+
cudaMemcpy(
493+
dstHost, state.dstBuffer, kBufferSize, cudaMemcpyDeviceToHost));
494+
for (int i = 0; i < kBufferNElem; ++i) {
495+
EXPECT_EQ(srcHost[i], dstHost[i]);
496+
}
497+
});
498+
}
499+
413500
} // namespace ctran::testing

0 commit comments

Comments
 (0)