Skip to content

Commit f6eaac8

Browse files
Chao1Handvrogozhmengfei25
authored
separate comm init from getXCClComm (#2090)
For split initXCCLComm and getXCCLComm. align with cuda. --------- Co-authored-by: Dmitry Rogozhkin <dmitry.v.rogozhkin@intel.com> Co-authored-by: mengfei25 <mengfei.li@Intel.com>
1 parent 91665cb commit f6eaac8

File tree

2 files changed

+22
-10
lines changed

2 files changed

+22
-10
lines changed

src/xccl/ProcessGroupXCCL.cpp

Lines changed: 19 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -502,6 +502,17 @@ c10::intrusive_ptr<ProcessGroupXCCL::WorkXCCL> ProcessGroupXCCL::initWork(
502502
}
503503

504504
std::shared_ptr<xcclComm_t> ProcessGroupXCCL::getXCCLComm(
505+
const std::string& deviceKey) {
506+
std::lock_guard<std::mutex> lock(mutex_);
507+
auto it = devXCCLCommMap_.find(deviceKey);
508+
if (it != devXCCLCommMap_.end()) {
509+
// Reuse the cached communicator if there is one.
510+
return it->second;
511+
}
512+
return nullptr;
513+
}
514+
515+
std::shared_ptr<xcclComm_t> ProcessGroupXCCL::initXCCLComm(
505516
const std::string& deviceKey,
506517
at::Device& device,
507518
OpType opType,
@@ -516,13 +527,6 @@ std::shared_ptr<xcclComm_t> ProcessGroupXCCL::getXCCLComm(
516527

517528
usedDeviceIdxs_.insert(device.index());
518529

519-
{
520-
std::lock_guard<std::mutex> lock(mutex_);
521-
if (devXCCLCommMap_.find(deviceKey) != devXCCLCommMap_.end()) {
522-
return devXCCLCommMap_[deviceKey];
523-
}
524-
}
525-
526530
std::shared_ptr<xcclComm_t> XCCLComm;
527531

528532
bool batchP2P = xcclActiveGroupCounter_ > 0;
@@ -680,7 +684,10 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::collective(
680684
nanCheck &= enableNanCheck_;
681685
auto device = inputs[0].device();
682686
const auto key = std::to_string(device.index());
683-
auto comm = getXCCLComm(key, device, opType);
687+
std::shared_ptr<xcclComm_t> comm = getXCCLComm(key);
688+
if (comm == nullptr) {
689+
comm = initXCCLComm(key, device, opType);
690+
}
684691

685692
if (!coalescing_state_) {
686693
seqCollective_++;
@@ -846,7 +853,10 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::pointToPoint(
846853
}
847854

848855
op_id_++;
849-
auto comm = getXCCLComm(key, device, opType, p2pRank, isSendRecvSelf);
856+
std::shared_ptr<xcclComm_t> comm = getXCCLComm(key);
857+
if (comm == nullptr) {
858+
comm = initXCCLComm(key, device, opType, p2pRank, isSendRecvSelf);
859+
}
850860

851861
if (coalescing_state_ & CoalActive) {
852862
if ((coalescing_state_ & CoalP2P) == 0) {

src/xccl/ProcessGroupXCCL.hpp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -166,7 +166,9 @@ class TORCH_API ProcessGroupXCCL : public Backend {
166166

167167
c10::intrusive_ptr<Work> endCoalescing(OpType optype);
168168

169-
std::shared_ptr<xcclComm_t> getXCCLComm(
169+
std::shared_ptr<xcclComm_t> getXCCLComm(const std::string& deviceKey);
170+
171+
std::shared_ptr<xcclComm_t> initXCCLComm(
170172
const std::string& deviceKey,
171173
at::Device& device,
172174
OpType opType,

0 commit comments

Comments
 (0)