@@ -149,6 +149,17 @@ TEST_P(CtranAllReduceTest, BasicRunAbortEnabled) {
149149 }
150150}
151151
152+ TEST_P(CtranAllReduceTest, SmallMessageSize) {
153+ auto [algoName, algo] = GetParam();
154+ NCCL_ALLREDUCE_ALGO = algo;
155+ startWorkers(/* abortEnabled=*/true);
156+ // Hand the same one-element runAllReduce job to every rank in [0, kNRanks).
157+ auto singleElemJob = [this](PerRankState& rankState) { runAllReduce(/* nElem=*/1, rankState); };
158+ for (int r = 0; r < kNRanks; ++r) {
159+ run(r, singleElemJob);
160+ }
161+ }
162+
152163void CtranAllReduceTest::runTestRanksAbsent (
153164 std::vector<int > ranksToRunCollective,
154165 std::vector<int > ranksAbsent,
@@ -425,78 +436,89 @@ class CtranAllReduceRingOneRankTest : public CtranStandaloneMultiRankBaseTest {
425436
426437 CtranStandaloneMultiRankBaseTest::SetUp ();
427438 }
428- };
429439
430- TEST_F (CtranAllReduceRingOneRankTest, Basic) {
431- ASSERT_EQ (NCCL_ALLREDUCE_ALGO, NCCL_ALLREDUCE_ALGO::ctring);
432-
433- CtranStandaloneMultiRankBaseTest::startWorkers (
434- kNRanks , /* aborts=*/ {ctran::utils::createAbort (/* enabled=*/ true )});
435-
436- run (/* rank=*/ 0 , [this ](PerRankState& state) {
437- // set up src buffer to hold magic values, and zero out dst buffers
438- int magic = 0xdeadbeef ;
439- int srcHost[kBufferNElem ];
440- int dstHost[kBufferNElem ];
441- for (int i = 0 ; i < kBufferNElem ; ++i) {
442- srcHost[i] = magic + i;
443- }
444- memset (dstHost, 0 , kBufferSize );
445- ASSERT_EQ (
446- cudaSuccess,
447- cudaMemcpy (
448- state.srcBuffer , srcHost, kBufferSize , cudaMemcpyHostToDevice));
449- ASSERT_EQ (cudaSuccess, cudaMemset (state.dstBuffer , 0 , kBufferSize ));
440+ void runAllReduce (size_t nElem) { // Runs ctranAllReduce over the first nElem elements on rank 0 and validates dst against src.
441+ ASSERT_EQ (NCCL_ALLREDUCE_ALGO, NCCL_ALLREDUCE_ALGO::ctring);
450442
451- // warmup
452- void * srcHandle;
453- void * dstHandle;
454- ASSERT_EQ (
455- commSuccess,
456- state.ctranComm ->ctran_ ->commRegister (
457- state.srcBuffer , kBufferSize , &srcHandle));
458- ASSERT_EQ (
459- commSuccess,
460- state.ctranComm ->ctran_ ->commRegister (
461- state.dstBuffer , kBufferSize , &dstHandle));
462- SCOPE_EXIT {
463- // deregistering will happen after streamSync below
464- state.ctranComm ->ctran_ ->commDeregister (dstHandle);
465- state.ctranComm ->ctran_ ->commDeregister (srcHandle);
466- };
443+ CtranStandaloneMultiRankBaseTest::startWorkers (
444+ kNRanks , /* aborts=*/ {ctran::utils::createAbort (/* enabled=*/ true )});
467445
468- CLOGF (INFO, " rank {} allReduce completed registration" , state.rank );
446+ run (/* rank=*/ 0 , [this , nElem](PerRankState& state) {
447+ // set up src buffer to hold magic values, and zero out dst buffers
448+ int magic = 0xdeadbeef ;
449+ int srcHost[kBufferNElem ];
450+ int dstHost[kBufferNElem ];
451+ for (int i = 0 ; i < kBufferNElem ; ++i) {
452+ srcHost[i] = magic + i;
453+ }
454+ memset (dstHost, 0 , kBufferSize );
455+ ASSERT_EQ (
456+ cudaSuccess,
457+ cudaMemcpy (
458+ state.srcBuffer , srcHost, kBufferSize , cudaMemcpyHostToDevice));
459+ ASSERT_EQ (cudaSuccess, cudaMemset (state.dstBuffer , 0 , kBufferSize ));
469460
470- EXPECT_EQ (
471- commSuccess,
472- ctranAllReduce (
473- state.srcBuffer ,
474- state.dstBuffer ,
475- kBufferNElem ,
476- kDataType ,
477- kReduceOpType ,
478- state.ctranComm .get (),
479- state.stream ,
480- std::nullopt ,
481- /* timeout=*/ std::nullopt ));
461+ // register the src and dst buffers with ctran before issuing the collective
462+ void * srcHandle;
463+ void * dstHandle;
464+ ASSERT_EQ (
465+ commSuccess,
466+ state.ctranComm ->ctran_ ->commRegister (
467+ state.srcBuffer , kBufferSize , &srcHandle));
468+ ASSERT_EQ (
469+ commSuccess,
470+ state.ctranComm ->ctran_ ->commRegister (
471+ state.dstBuffer , kBufferSize , &dstHandle));
472+ SCOPE_EXIT {
473+ // deregistering will happen after streamSync below
474+ state.ctranComm ->ctran_ ->commDeregister (dstHandle);
475+ state.ctranComm ->ctran_ ->commDeregister (srcHandle);
476+ };
477+ ASSERT_LE (nElem, static_cast<size_t>(kBufferNElem)); // srcHost/dstHost hold only kBufferNElem elements
478+ CLOGF (INFO, " rank {} allReduce completed registration" , state.rank );
479+
480+ EXPECT_EQ (
481+ commSuccess,
482+ ctranAllReduce (
483+ state.srcBuffer ,
484+ state.dstBuffer ,
485+ nElem,
486+ kDataType ,
487+ kReduceOpType ,
488+ state.ctranComm .get (),
489+ state.stream ,
490+ std::nullopt ,
491+ /* timeout=*/ std::nullopt ));
492+
493+ CLOGF (INFO, " rank {} allReduce scheduled" , state.rank );
494+
495+ // ensure async execution completion and no error
496+ EXPECT_EQ (cudaSuccess, cudaStreamSynchronize (state.stream ));
497+ EXPECT_EQ (commSuccess, state.ctranComm ->getAsyncResult ());
482498
483- CLOGF (INFO, " rank {} allReduce scheduled " , state.rank );
499+ CLOGF (INFO, " rank {} allReduce task completed " , state.rank );
484500
485- // ensure async execution completion and no error
486- EXPECT_EQ (cudaSuccess, cudaStreamSynchronize (state.stream ));
487- EXPECT_EQ (commSuccess, state.ctranComm ->getAsyncResult ());
501+ // validate results: the first nElem elements must equal src, the rest must stay zero
502+ ASSERT_EQ (
503+ cudaSuccess,
504+ cudaMemcpy (
505+ dstHost, state.dstBuffer , kBufferSize , cudaMemcpyDeviceToHost));
506+ for (size_t i = 0 ; i < nElem; ++i) { // size_t index matches nElem's type (no signed/unsigned compare)
507+ EXPECT_EQ (srcHost[i], dstHost[i]);
508+ }
509+ for (size_t i = nElem; i < static_cast<size_t>(kBufferNElem); ++i) { // untouched tail keeps the zero fill
510+ EXPECT_EQ (dstHost[i], 0 );
511+ }
512+ });
513+ }
514+ };
488515
489- CLOGF (INFO, " rank {} allReduce task completed" , state.rank );
516+ TEST_F(CtranAllReduceRingOneRankTest, Basic) {
517+ runAllReduce(/* nElem=*/kBufferNElem); // reduce over all kBufferNElem elements
518+ }
490519
491- // validate results
492- ASSERT_EQ (
493- cudaSuccess,
494- cudaMemcpy (
495- dstHost, state.dstBuffer , kBufferSize , cudaMemcpyDeviceToHost));
496- for (int i = 0 ; i < kBufferNElem ; ++i) {
497- EXPECT_EQ (srcHost[i], dstHost[i]);
498- }
499- });
520+ TEST_F(CtranAllReduceRingOneRankTest, SmallMessageSize) {
521+ runAllReduce(/* nElem=*/1); // reduce a single element only
500522}
501523
502524} // namespace ctran::testing
0 commit comments