diff --git a/comms/torchcomms/nccl/TorchCommNCCL.hpp b/comms/torchcomms/nccl/TorchCommNCCL.hpp
index b6b21d93..7288f317 100644
--- a/comms/torchcomms/nccl/TorchCommNCCL.hpp
+++ b/comms/torchcomms/nccl/TorchCommNCCL.hpp
@@ -299,7 +299,7 @@ class TorchCommNCCL : public TorchCommBackend,
   void timeoutWatchdog() noexcept;
   void checkInitialized() const;
   void checkAndAbortIfTimedOutOrError();
-  void checkWorkQueue(bool isMainThread);
+  void checkWorkQueue();
   void enqueueWork(std::shared_ptr<TorchWorkNCCL> work, cudaStream_t stream);
   bool getGraphCaptureMode();
   cudaStream_t getOperationStream(bool async_op);
diff --git a/comms/torchcomms/nccl/TorchCommNCCLUtils.cpp b/comms/torchcomms/nccl/TorchCommNCCLUtils.cpp
index 2f9e8f12..36356127 100644
--- a/comms/torchcomms/nccl/TorchCommNCCLUtils.cpp
+++ b/comms/torchcomms/nccl/TorchCommNCCLUtils.cpp
@@ -175,8 +175,8 @@ TorchCommNCCL::RedOpRAII TorchCommNCCL::getNcclReduceOp(
   }
 }
 
-void TorchCommNCCL::checkWorkQueue(bool isMainThread) {
-  TorchWorkNCCL::WorkStatus status = workq_.garbageCollect(isMainThread);
+void TorchCommNCCL::checkWorkQueue() {
+  TorchWorkNCCL::WorkStatus status = workq_.garbageCollect();
 
   switch (status) {
     case TorchWorkNCCL::WorkStatus::TIMEDOUT:
@@ -210,7 +210,7 @@ void TorchCommNCCL::timeoutWatchdog() noexcept {
     }
 
     // Check work objects for completion or timeout
-    checkWorkQueue(false);
+    checkWorkQueue();
 
     if (comm_state_ != CommState::NORMAL &&
         options_.abort_process_on_timeout_or_error) {
       // Log the error and abort the process. We cannot abort the NCCL
@@ -243,7 +243,7 @@ void TorchCommNCCL::checkAndAbortIfTimedOutOrError() {
   }
 
   // First, check work queue status
-  checkWorkQueue(true);
+  checkWorkQueue();
 
   if (comm_state_ == CommState::TIMEOUT) {
     abortNcclComm();
diff --git a/comms/torchcomms/nccl/TorchWorkNCCL.hpp b/comms/torchcomms/nccl/TorchWorkNCCL.hpp
index 011e5b6d..351e39f9 100644
--- a/comms/torchcomms/nccl/TorchWorkNCCL.hpp
+++ b/comms/torchcomms/nccl/TorchWorkNCCL.hpp
@@ -91,7 +91,7 @@ class TorchWorkNCCLQueue {
   TorchWorkNCCLQueue() = default;
   ~TorchWorkNCCLQueue() = default;
 
-  TorchWorkNCCL::WorkStatus garbageCollect(bool isMainThread);
+  TorchWorkNCCL::WorkStatus garbageCollect();
   // Finalize function can only be called from the main thread
   TorchWorkNCCL::WorkStatus finalize();
   void enqueueWork(std::shared_ptr<TorchWorkNCCL> work, cudaStream_t stream);
@@ -99,7 +99,6 @@ class TorchWorkNCCLQueue {
 private:
   std::unordered_map<cudaStream_t, std::queue<std::shared_ptr<TorchWorkNCCL>>>
       stream_work_queues_;
-  std::vector<std::shared_ptr<TorchWorkNCCL>> completed_work_queue_;
   std::recursive_mutex work_queues_mutex_;
 };
diff --git a/comms/torchcomms/nccl/TorchWorkNCCLQueue.cpp b/comms/torchcomms/nccl/TorchWorkNCCLQueue.cpp
index 3d726ca3..0cd1e718 100644
--- a/comms/torchcomms/nccl/TorchWorkNCCLQueue.cpp
+++ b/comms/torchcomms/nccl/TorchWorkNCCLQueue.cpp
@@ -5,8 +5,7 @@ namespace torch {
 namespace comms {
 
-TorchWorkNCCL::WorkStatus TorchWorkNCCLQueue::garbageCollect(
-    bool isMainThread) {
+TorchWorkNCCL::WorkStatus TorchWorkNCCLQueue::garbageCollect() {
   std::lock_guard lock(work_queues_mutex_);
   TorchWorkNCCL::WorkStatus last_status = TorchWorkNCCL::WorkStatus::COMPLETED;
@@ -29,7 +28,6 @@ TorchWorkNCCL::WorkStatus TorchWorkNCCLQueue::garbageCollect(
       if (status == TorchWorkNCCL::WorkStatus::COMPLETED) {
         // Work is completed, remove it from the work queue
         work_queue.pop();
-        completed_work_queue_.push_back(work);
         // Continue to the next element in the queue
       } else if (
           status == TorchWorkNCCL::WorkStatus::TIMEDOUT ||
@@ -50,11 +48,6 @@ TorchWorkNCCL::WorkStatus TorchWorkNCCLQueue::garbageCollect(
     }
   }
 
-  if (isMainThread) {
-    // If we are the main thread, clear the completed work queues
-    completed_work_queue_.clear();
-  }
-
   return last_status;
 }
 
@@ -70,7 +63,7 @@ TorchWorkNCCL::WorkStatus TorchWorkNCCLQueue::finalize() {
   // empty
   TorchWorkNCCL::WorkStatus status = TorchWorkNCCL::WorkStatus::COMPLETED;
   while (!stream_work_queues_.empty()) {
-    status = garbageCollect(true);
+    status = garbageCollect();
     if (status == TorchWorkNCCL::WorkStatus::ERROR ||
         status == TorchWorkNCCL::WorkStatus::TIMEDOUT ||
         status == TorchWorkNCCL::WorkStatus::COMPLETED) {
@@ -83,7 +76,6 @@ TorchWorkNCCL::WorkStatus TorchWorkNCCLQueue::finalize() {
   // NOTE: finalize MUST return without holding references to any work object,
   // otherwise it may leak object and cause side effects.
   stream_work_queues_.clear();
-  completed_work_queue_.clear();
   return status;
 }