Skip to content

Commit a46b518

Browse files
arttianezhumeta-codesync[bot]
authored andcommitted
make AllReduce ctring fallback to ctdirect if message size is small
Summary: As title. Reviewed By: function47 Differential Revision: D85377269 fbshipit-source-id: 1df7f76d9ee632ddd0b0d55e93ba00c10842af91
1 parent ddf456d commit a46b518

File tree

2 files changed

+95
-64
lines changed

2 files changed

+95
-64
lines changed

comms/ctran/algos/AllReduce/AllReduce.cc

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,15 @@ commResult_t ctranAllReduce(
3939
return ctranAllReduceDirect(
4040
sendbuff, recvbuff, count, datatype, redOp, comm, stream, timeout);
4141
}
42+
if (count < comm->statex_->nRanks()) {
43+
CLOGF(
44+
WARN,
45+
"AllReduce ctring requires count {} > nRanks {}, fallback to ctdirect",
46+
count,
47+
comm->statex_->nRanks());
48+
return ctranAllReduceDirect(
49+
sendbuff, recvbuff, count, datatype, redOp, comm, stream, timeout);
50+
}
4251
return ctranAllReduceRing(
4352
sendbuff, recvbuff, count, datatype, redOp, comm, stream, timeout);
4453
case NCCL_ALLREDUCE_ALGO::ctdirect:

comms/ctran/tests/CtranAllReduceTest.cc

Lines changed: 86 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -149,6 +149,17 @@ TEST_P(CtranAllReduceTest, BasicRunAbortEnabled) {
149149
}
150150
}
151151

152+
TEST_P(CtranAllReduceTest, SmallMessageSize) {
153+
auto [algoName, algo] = GetParam();
154+
NCCL_ALLREDUCE_ALGO = algo;
155+
156+
startWorkers(/*abortEnabled=*/true);
157+
for (int rank = 0; rank < kNRanks; ++rank) {
158+
run(rank,
159+
[this](PerRankState& state) { runAllReduce(/*nElem=*/1, state); });
160+
}
161+
}
162+
152163
void CtranAllReduceTest::runTestRanksAbsent(
153164
std::vector<int> ranksToRunCollective,
154165
std::vector<int> ranksAbsent,
@@ -425,78 +436,89 @@ class CtranAllReduceRingOneRankTest : public CtranStandaloneMultiRankBaseTest {
425436

426437
CtranStandaloneMultiRankBaseTest::SetUp();
427438
}
428-
};
429439

430-
TEST_F(CtranAllReduceRingOneRankTest, Basic) {
431-
ASSERT_EQ(NCCL_ALLREDUCE_ALGO, NCCL_ALLREDUCE_ALGO::ctring);
432-
433-
CtranStandaloneMultiRankBaseTest::startWorkers(
434-
kNRanks, /*aborts=*/{ctran::utils::createAbort(/*enabled=*/true)});
435-
436-
run(/*rank=*/0, [this](PerRankState& state) {
437-
// set up src buffer to hold magic values, and zero out dst buffers
438-
int magic = 0xdeadbeef;
439-
int srcHost[kBufferNElem];
440-
int dstHost[kBufferNElem];
441-
for (int i = 0; i < kBufferNElem; ++i) {
442-
srcHost[i] = magic + i;
443-
}
444-
memset(dstHost, 0, kBufferSize);
445-
ASSERT_EQ(
446-
cudaSuccess,
447-
cudaMemcpy(
448-
state.srcBuffer, srcHost, kBufferSize, cudaMemcpyHostToDevice));
449-
ASSERT_EQ(cudaSuccess, cudaMemset(state.dstBuffer, 0, kBufferSize));
440+
void runAllReduce(size_t nElem) {
441+
ASSERT_EQ(NCCL_ALLREDUCE_ALGO, NCCL_ALLREDUCE_ALGO::ctring);
450442

451-
// warmup
452-
void* srcHandle;
453-
void* dstHandle;
454-
ASSERT_EQ(
455-
commSuccess,
456-
state.ctranComm->ctran_->commRegister(
457-
state.srcBuffer, kBufferSize, &srcHandle));
458-
ASSERT_EQ(
459-
commSuccess,
460-
state.ctranComm->ctran_->commRegister(
461-
state.dstBuffer, kBufferSize, &dstHandle));
462-
SCOPE_EXIT {
463-
// deregistering will happen after streamSync below
464-
state.ctranComm->ctran_->commDeregister(dstHandle);
465-
state.ctranComm->ctran_->commDeregister(srcHandle);
466-
};
443+
CtranStandaloneMultiRankBaseTest::startWorkers(
444+
kNRanks, /*aborts=*/{ctran::utils::createAbort(/*enabled=*/true)});
467445

468-
CLOGF(INFO, "rank {} allReduce completed registration", state.rank);
446+
run(/*rank=*/0, [this, nElem](PerRankState& state) {
447+
// set up src buffer to hold magic values, and zero out dst buffers
448+
int magic = 0xdeadbeef;
449+
int srcHost[kBufferNElem];
450+
int dstHost[kBufferNElem];
451+
for (int i = 0; i < kBufferNElem; ++i) {
452+
srcHost[i] = magic + i;
453+
}
454+
memset(dstHost, 0, kBufferSize);
455+
ASSERT_EQ(
456+
cudaSuccess,
457+
cudaMemcpy(
458+
state.srcBuffer, srcHost, kBufferSize, cudaMemcpyHostToDevice));
459+
ASSERT_EQ(cudaSuccess, cudaMemset(state.dstBuffer, 0, kBufferSize));
469460

470-
EXPECT_EQ(
471-
commSuccess,
472-
ctranAllReduce(
473-
state.srcBuffer,
474-
state.dstBuffer,
475-
kBufferNElem,
476-
kDataType,
477-
kReduceOpType,
478-
state.ctranComm.get(),
479-
state.stream,
480-
std::nullopt,
481-
/*timeout=*/std::nullopt));
461+
// warmup
462+
void* srcHandle;
463+
void* dstHandle;
464+
ASSERT_EQ(
465+
commSuccess,
466+
state.ctranComm->ctran_->commRegister(
467+
state.srcBuffer, kBufferSize, &srcHandle));
468+
ASSERT_EQ(
469+
commSuccess,
470+
state.ctranComm->ctran_->commRegister(
471+
state.dstBuffer, kBufferSize, &dstHandle));
472+
SCOPE_EXIT {
473+
// deregistering will happen after streamSync below
474+
state.ctranComm->ctran_->commDeregister(dstHandle);
475+
state.ctranComm->ctran_->commDeregister(srcHandle);
476+
};
477+
478+
CLOGF(INFO, "rank {} allReduce completed registration", state.rank);
479+
480+
EXPECT_EQ(
481+
commSuccess,
482+
ctranAllReduce(
483+
state.srcBuffer,
484+
state.dstBuffer,
485+
nElem,
486+
kDataType,
487+
kReduceOpType,
488+
state.ctranComm.get(),
489+
state.stream,
490+
std::nullopt,
491+
/*timeout=*/std::nullopt));
492+
493+
CLOGF(INFO, "rank {} allReduce scheduled", state.rank);
494+
495+
// ensure async execution completion and no error
496+
EXPECT_EQ(cudaSuccess, cudaStreamSynchronize(state.stream));
497+
EXPECT_EQ(commSuccess, state.ctranComm->getAsyncResult());
482498

483-
CLOGF(INFO, "rank {} allReduce scheduled", state.rank);
499+
CLOGF(INFO, "rank {} allReduce task completed", state.rank);
484500

485-
// ensure async execution completion and no error
486-
EXPECT_EQ(cudaSuccess, cudaStreamSynchronize(state.stream));
487-
EXPECT_EQ(commSuccess, state.ctranComm->getAsyncResult());
501+
// validate results
502+
ASSERT_EQ(
503+
cudaSuccess,
504+
cudaMemcpy(
505+
dstHost, state.dstBuffer, kBufferSize, cudaMemcpyDeviceToHost));
506+
for (int i = 0; i < nElem; ++i) {
507+
EXPECT_EQ(srcHost[i], dstHost[i]);
508+
}
509+
for (int i = nElem; i < kBufferNElem; ++i) {
510+
EXPECT_EQ(dstHost[i], 0);
511+
}
512+
});
513+
}
514+
};
488515

489-
CLOGF(INFO, "rank {} allReduce task completed", state.rank);
516+
TEST_F(CtranAllReduceRingOneRankTest, Basic) {
517+
this->runAllReduce(/*nElem=*/kBufferNElem);
518+
}
490519

491-
// validate results
492-
ASSERT_EQ(
493-
cudaSuccess,
494-
cudaMemcpy(
495-
dstHost, state.dstBuffer, kBufferSize, cudaMemcpyDeviceToHost));
496-
for (int i = 0; i < kBufferNElem; ++i) {
497-
EXPECT_EQ(srcHost[i], dstHost[i]);
498-
}
499-
});
520+
TEST_F(CtranAllReduceRingOneRankTest, SmallMessageSize) {
521+
this->runAllReduce(/*nElem=*/1);
500522
}
501523

502524
} // namespace ctran::testing

0 commit comments

Comments
 (0)