@@ -410,4 +410,91 @@ INSTANTIATE_TEST_SUITE_P(
410410 std::make_tuple<>(8 , commInt32),
411411 std::make_tuple<>(8 , commInt8)));
412412
413+ class CtranAllReduceRingOneRankTest : public CtranStandaloneMultiRankBaseTest {
414+ protected:
415+ static constexpr int kNRanks = 1 ;
416+ static constexpr commRedOp_t kReduceOpType = commSum;
417+ static constexpr commDataType_t kDataType = commInt32;
418+ static constexpr size_t kTypeSize = sizeof (int );
419+ static constexpr size_t kBufferNElem = kBufferSize / kTypeSize ;
420+
421+ void SetUp () override {
422+ setenv (" NCCL_ALLREDUCE_ALGO" , " ctring" , 1 );
423+
424+ CtranStandaloneMultiRankBaseTest::SetUp ();
425+ }
426+ };
427+
428+ TEST_F (CtranAllReduceRingOneRankTest, Basic) {
429+ ASSERT_EQ (NCCL_ALLREDUCE_ALGO, NCCL_ALLREDUCE_ALGO::ctring);
430+
431+ CtranStandaloneMultiRankBaseTest::startWorkers (
432+ kNRanks , /* aborts=*/ {ctran::utils::createAbort (/* enabled=*/ true )});
433+
434+ run (/* rank=*/ 0 , [this ](PerRankState& state) {
435+ // set up src buffer to hold magic values, and zero out dst buffers
436+ int magic = 0xdeadbeef ;
437+ int srcHost[kBufferNElem ];
438+ int dstHost[kBufferNElem ];
439+ for (int i = 0 ; i < kBufferNElem ; ++i) {
440+ srcHost[i] = magic + i;
441+ }
442+ memset (dstHost, 0 , kBufferSize );
443+ ASSERT_EQ (
444+ cudaSuccess,
445+ cudaMemcpy (
446+ state.srcBuffer , srcHost, kBufferSize , cudaMemcpyHostToDevice));
447+ ASSERT_EQ (cudaSuccess, cudaMemset (state.dstBuffer , 0 , kBufferSize ));
448+
449+ // warmup
450+ void * srcHandle;
451+ void * dstHandle;
452+ ASSERT_EQ (
453+ commSuccess,
454+ state.ctranComm ->ctran_ ->commRegister (
455+ state.srcBuffer , kBufferSize , &srcHandle));
456+ ASSERT_EQ (
457+ commSuccess,
458+ state.ctranComm ->ctran_ ->commRegister (
459+ state.dstBuffer , kBufferSize , &dstHandle));
460+ SCOPE_EXIT {
461+ // deregistering will happen after streamSync below
462+ state.ctranComm ->ctran_ ->commDeregister (dstHandle);
463+ state.ctranComm ->ctran_ ->commDeregister (srcHandle);
464+ };
465+
466+ CLOGF (INFO, " rank {} allReduce completed registration" , state.rank );
467+
468+ EXPECT_EQ (
469+ commSuccess,
470+ ctranAllReduce (
471+ state.srcBuffer ,
472+ state.dstBuffer ,
473+ kBufferNElem ,
474+ kDataType ,
475+ kReduceOpType ,
476+ state.ctranComm .get (),
477+ state.stream ,
478+ std::nullopt ,
479+ /* timeout=*/ std::nullopt ));
480+
481+ CLOGF (INFO, " rank {} allReduce scheduled" , state.rank );
482+
483+ // ensure async execution completion and no error
484+ EXPECT_EQ (cudaSuccess, cudaStreamSynchronize (state.stream ));
485+ EXPECT_EQ (commSuccess, state.ctranComm ->getAsyncResult ());
486+
487+ CLOGF (INFO, " rank {} allReduce task completed" , state.rank );
488+
489+ // validate results
490+ ASSERT_EQ (
491+ cudaSuccess,
492+ cudaMemcpy (
493+ dstHost, state.dstBuffer , kBufferSize , cudaMemcpyDeviceToHost));
494+ for (int i = 0 ; i < kBufferNElem ; ++i) {
495+ EXPECT_EQ (srcHost[i], dstHost[i]);
496+ }
497+ });
498+ }
499+
413500} // namespace ctran::testing
0 commit comments