Skip to content

Commit 1cda6ae

Browse files
Saif Hasanmeta-codesync[bot]
authored andcommitted
Change submitHost parameter from exReq to cpuFlag
Summary: The GPE thread only needs the `std::atomic_flag* cpuFlag` for signaling completion, so passing the whole `CtranExRequestImpl* exReq` object is unnecessary. This simplifies using `broadcastBinomialTree` with just a flag instead of requiring the full `exReq` object. Changes include: - Updated function signatures in CtranGpe.h, CtranGpe.cc, CtranGpeImpl.h, and CtranGpeImpl.cc - Changed CtranGpeCmd struct to use `cpuFlag` instead of `exReq` - Updated GPE thread implementation to call `cpuFlag->test_and_set()` instead of `exReq->complete()` - Updated `CtranAlgo::broadcastBinomialTree` signature and implementation - Updated call sites in ncclx v2_27 and v2_28 wrappers to pass `&reqImpl->bcast.complete` This change maintains the same functionality while providing a cleaner API that only exposes what the GPE thread actually needs. Reviewed By: Regina8023 Differential Revision: D86544452 fbshipit-source-id: 7a71623db10e762cf27d852cabddd0a3f2db5420
1 parent 6abe4f7 commit 1cda6ae

File tree

7 files changed

+22
-17
lines changed

7 files changed

+22
-17
lines changed

comms/ctran/algos/Broadcast/BroadcastBinomialTree.cc

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -440,7 +440,7 @@ commResult_t CtranAlgo::broadcastBinomialTree(
440440
size_t count,
441441
commDataType_t datatype,
442442
int root,
443-
CtranExRequestImpl* exReq) {
443+
std::atomic_flag* cpuFlag) {
444444
auto opCount = ctran_->getOpCount();
445445
CTRAN_HOST_COLL_INFO(
446446
broadcastAlgoName(myAlgo).c_str(),
@@ -451,7 +451,7 @@ commResult_t CtranAlgo::broadcastBinomialTree(
451451
root,
452452
comm_,
453453
ctran_,
454-
exReq);
454+
cpuFlag);
455455
const auto statex = comm_->statex_.get();
456456

457457
if (sendbuff != recvbuff && statex->rank() == root) {
@@ -489,8 +489,8 @@ commResult_t CtranAlgo::broadcastBinomialTree(
489489
config.args.collective.broadcast.datatype = datatype;
490490
config.args.collective.broadcast.count = count;
491491

492-
FB_COMMCHECK(
493-
comm_->ctran_->gpe->submitHost(std::move(opGroup), impl, config, exReq));
492+
FB_COMMCHECK(comm_->ctran_->gpe->submitHost(
493+
std::move(opGroup), impl, config, cpuFlag));
494494

495495
return commSuccess;
496496
}

comms/ctran/algos/CtranAlgo.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@ class CtranAlgo {
5757
size_t count,
5858
commDataType_t datatype,
5959
int root,
60-
::ctran::CtranExRequestImpl* exReq);
60+
std::atomic_flag* cpuFlag);
6161

6262
commResult_t initTmpBufs();
6363
commResult_t initAllReduceDirectResource(int nBlocks, cudaStream_t stream);

comms/ctran/gpe/CtranGpe.cc

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -389,13 +389,13 @@ commResult_t CtranGpe::submitHost(
389389
std::vector<std::unique_ptr<struct OpElem>> opGroup,
390390
opFunc func,
391391
KernelConfig& kernelConfig,
392-
CtranExRequestImpl* exReq) {
392+
std::atomic_flag* cpuFlag) {
393393
return this->pimpl->submitHost(
394394
CtranGpeCmd::TypeEnum::GRAPH_ENQUEUE,
395395
std::move(opGroup),
396396
func,
397397
kernelConfig,
398-
exReq);
398+
cpuFlag);
399399
}
400400

401401
commResult_t CtranGpe::allocKernelElems(

comms/ctran/gpe/CtranGpe.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -385,14 +385,14 @@ class CtranGpe {
385385
// Submit host mem communication. No kernel is launched, and only the host
386386
// side func will be submitted to the GPE thread. Also the op won't be
387387
// captured by cudagraph.
388-
// Completion of the operation is tracked by exReq. exReq can be nullptr,
388+
// Completion of the operation is tracked by cpuFlag. cpuFlag can be nullptr,
389389
// indicating that the caller doesn't care about the completion of the
390390
// operation.
391391
commResult_t submitHost(
392392
std::vector<std::unique_ptr<struct OpElem>> opGroup,
393393
opFunc func,
394394
KernelConfig& kernelConfig,
395-
::ctran::CtranExRequestImpl* exReq);
395+
std::atomic_flag* cpuFlag);
396396

397397
// Allocate numElems number of p2pElem objects from internal pool.
398398
// When free objects are not enough, it will be in blocking wait and reclaim

comms/ctran/gpe/CtranGpeImpl.cc

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -394,13 +394,13 @@ commResult_t CtranGpe::Impl::submitHost(
394394
std::vector<std::unique_ptr<struct OpElem>> opGroup,
395395
opFunc func,
396396
KernelConfig& kernelConfig,
397-
CtranExRequestImpl* exReq) {
397+
std::atomic_flag* cpuFlag) {
398398
// Enqueue op to gpeThread if any op is appended
399399
if (!opGroup.empty()) {
400400
class CtranGpeCmd* cmd = new class CtranGpeCmd;
401401
cmd->type = type;
402402
cmd->kernelFlag = nullptr;
403-
cmd->exReq = exReq;
403+
cmd->cpuFlag = cpuFlag;
404404

405405
if (type == CtranGpeCmd::TypeEnum::GRAPH_ENQUEUE) {
406406
cmd->coll.opGroup = std::move(opGroup);
@@ -531,8 +531,8 @@ void CtranGpe::Impl::gpeThreadFn() {
531531
// Ensure the host communication request completes, irrespective of
532532
// outcome of collective function - success, failure, or an exception
533533
SCOPE_EXIT {
534-
if (cmd->exReq) {
535-
cmd->exReq->complete();
534+
if (cmd->cpuFlag) {
535+
cmd->cpuFlag->test_and_set();
536536
}
537537
};
538538

comms/ctran/gpe/CtranGpeImpl.h

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -123,8 +123,8 @@ class CtranGpeCmd {
123123

124124
// kernelFlag to assist device mem communication
125125
KernelFlagItem* kernelFlag{nullptr};
126-
// request to track completion of host mem communication
127-
::ctran::CtranExRequestImpl* exReq{nullptr};
126+
// cpuFlag to track completion of host mem communication
127+
std::atomic_flag* cpuFlag{nullptr};
128128

129129
bool persistent{false};
130130

@@ -202,7 +202,7 @@ class CtranGpe::Impl {
202202
std::vector<std::unique_ptr<struct OpElem>> opGroup,
203203
opFunc func,
204204
KernelConfig& kernelConfig,
205-
::ctran::CtranExRequestImpl* exReq);
205+
std::atomic_flag* cpuFlag);
206206

207207
// start the GPE thread.
208208
void start();

comms/ncclx/v2_27/meta/wrapper/CtranExComm.cc

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -139,7 +139,12 @@ ncclResult_t CtranExComm::broadcast(
139139

140140
// TODO: only tree supports host mem for now. Add support for direct.
141141
NCCLCHECK(metaCommToNccl(ctranComm->ctran_->algo->broadcastBinomialTree(
142-
sendbuff, recvbuff, count, ncclToMetaComm(datatype), root, reqImpl)));
142+
sendbuff,
143+
recvbuff,
144+
count,
145+
ncclToMetaComm(datatype),
146+
root,
147+
&reqImpl->bcast.complete)));
143148

144149
*req = reqPtr;
145150
return ncclSuccess;

0 commit comments

Comments
 (0)