@@ -502,6 +502,17 @@ c10::intrusive_ptr<ProcessGroupXCCL::WorkXCCL> ProcessGroupXCCL::initWork(
502502}
503503
504504std::shared_ptr<xcclComm_t> ProcessGroupXCCL::getXCCLComm (
505+ const std::string& deviceKey) {
506+ std::lock_guard<std::mutex> lock (mutex_);
507+ auto it = devXCCLCommMap_.find (deviceKey);
508+ if (it != devXCCLCommMap_.end ()) {
509+ // Reuse the cached communicator if there is one.
510+ return it->second ;
511+ }
512+ return nullptr ;
513+ }
514+
515+ std::shared_ptr<xcclComm_t> ProcessGroupXCCL::initXCCLComm (
505516 const std::string& deviceKey,
506517 at::Device& device,
507518 OpType opType,
@@ -516,13 +527,6 @@ std::shared_ptr<xcclComm_t> ProcessGroupXCCL::getXCCLComm(
516527
517528 usedDeviceIdxs_.insert (device.index ());
518529
519- {
520- std::lock_guard<std::mutex> lock (mutex_);
521- if (devXCCLCommMap_.find (deviceKey) != devXCCLCommMap_.end ()) {
522- return devXCCLCommMap_[deviceKey];
523- }
524- }
525-
526530 std::shared_ptr<xcclComm_t> XCCLComm;
527531
528532 bool batchP2P = xcclActiveGroupCounter_ > 0 ;
@@ -680,7 +684,10 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::collective(
680684 nanCheck &= enableNanCheck_;
681685 auto device = inputs[0 ].device ();
682686 const auto key = std::to_string (device.index ());
683- auto comm = getXCCLComm (key, device, opType);
687+ std::shared_ptr<xcclComm_t> comm = getXCCLComm (key);
688+ if (comm == nullptr ) {
689+ comm = initXCCLComm (key, device, opType);
690+ }
684691
685692 if (!coalescing_state_) {
686693 seqCollective_++;
@@ -846,7 +853,10 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::pointToPoint(
846853 }
847854
848855 op_id_++;
849- auto comm = getXCCLComm (key, device, opType, p2pRank, isSendRecvSelf);
856+ std::shared_ptr<xcclComm_t> comm = getXCCLComm (key);
857+ if (comm == nullptr ) {
858+ comm = initXCCLComm (key, device, opType, p2pRank, isSendRecvSelf);
859+ }
850860
851861 if (coalescing_state_ & CoalActive) {
852862 if ((coalescing_state_ & CoalP2P) == 0 ) {
0 commit comments