@@ -52,6 +52,14 @@ static const auto myAlgo = NCCL_ALLREDUCE_ALGO::ctdirect;
5252 * registers and reduces them into the receive buffer.
5353 */
5454
55+ #define THROW_IF_ABORTED (code ) \
56+ do { \
57+ code; \
58+ if (comm->testAbort ()) { \
59+ throw ctran::utils::Exception (" comm aborted" , commRemoteError); \
60+ } \
61+ } while (0 )
62+
5563static commResult_t impl (
5664 const std::vector<std::unique_ptr<struct OpElem >>& opGroup) {
5765 struct OpElem * op = opGroup.front ().get ();
@@ -250,7 +258,7 @@ static commResult_t impl(
250258 }
251259
252260 elem->post ();
253- elem->wait ();
261+ THROW_IF_ABORTED ( elem->wait (comm-> getAbort ()) );
254262
255263 /* Step 2: Inter-node Reduce-scatter */
256264 /* wait for inter-node data transfer to perform local reduction */
@@ -305,7 +313,7 @@ static commResult_t impl(
305313 elem->stridedReduce .stride = chunkCount;
306314 /* poke kernel to start the local reduction */
307315 elem->post ();
308- elem->wait ();
316+ THROW_IF_ABORTED ( elem->wait (comm-> getAbort ()) );
309317
310318 /* Step 3: Inter-node Allgather */
311319 /* wait for inter-node data transfer to perform local reduction */
@@ -346,6 +354,7 @@ static commResult_t impl(
346354 }
347355 }
348356 }
357+ THROW_IF_ABORTED ();
349358
350359 /* Step 4: Intra-node Allgather */
351360 elem = op->allreduce .kElemStepMap .at (
@@ -363,7 +372,7 @@ static commResult_t impl(
363372
364373 /* poke kernel to start the allgather */
365374 elem->post ();
366- elem->wait ();
375+ THROW_IF_ABORTED ( elem->wait (comm-> getAbort ()) );
367376
368377 op->allreduce .sendbuff =
369378 BUFOFFSET (op->allreduce .sendbuff , stepCount * typeSize);
@@ -407,7 +416,7 @@ static commResult_t impl(
407416 }
408417 }
409418 elem->post ();
410- elem->wait ();
419+ THROW_IF_ABORTED ( elem->wait (comm-> getAbort ()) );
411420
412421 /* Step 6: Intra-node bcast */
413422 elem = op->allreduce .kElemStepMap .at (
@@ -421,7 +430,7 @@ static commResult_t impl(
421430 }
422431 }
423432 elem->post ();
424- elem->wait ();
433+ THROW_IF_ABORTED ( elem->wait (comm-> getAbort ()) );
425434
426435 /* Step 7: Inter-node allreduce */
427436 /* wait for inter-node data transfer to perform local reduction */
@@ -472,7 +481,7 @@ static commResult_t impl(
472481 elem->stridedReduce .blockCount = remCount;
473482 elem->stridedReduce .stride = remCount;
474483 elem->post ();
475- elem->wait ();
484+ THROW_IF_ABORTED ( elem->wait (comm-> getAbort ()) );
476485 }
477486
478487 if (localRegSend == true ) {
@@ -693,8 +702,8 @@ commResult_t ctranAllReduceDirect(
693702
694703 opGroup.push_back (std::move (op));
695704
696- FB_COMMCHECK (
697- comm-> ctran_ -> gpe -> submit ( std::move (opGroup), impl, config, func));
705+ FB_COMMCHECK (comm-> ctran_ -> gpe -> submit (
706+ std::move (opGroup), impl, config, func, timeout ));
698707
699708 if (count * typeSize < CTRAN_MIN_REGISTRATION_SIZE) {
700709 FB_CUDACHECK (cudaMemcpyAsync (
0 commit comments