
Commit 73a225d

tianfengfrank authored and meta-codesync[bot] committed
enable all_gather_v support
Summary: tp_overlapping needs to work with the uneven_split support introduced in D84788079. To support that, we need all_gather_v in torchcomm:
- enable all_gather_v to support varying tensor sizes across the output tensor_list
- add both C++ and Python integration UTs

Reviewed By: d4l3k

Differential Revision: D85292529

fbshipit-source-id: 27b281de41121b7887e55248591881503979680a
1 parent f9a45be commit 73a225d

20 files changed (+452, -0 lines)

comms/torchcomms/TorchComm.cpp

Lines changed: 8 additions & 0 deletions
@@ -89,6 +89,14 @@ std::shared_ptr<TorchWork> TorchComm::all_gather(
   return impl_->all_gather(tensor_list, tensor, async_op, options);
 }
 
+std::shared_ptr<TorchWork> TorchComm::all_gather_v(
+    const std::vector<at::Tensor>& tensor_list,
+    const at::Tensor& tensor,
+    bool async_op,
+    const AllGatherOptions& options) {
+  return impl_->all_gather_v(tensor_list, tensor, async_op, options);
+}
+
 std::shared_ptr<TorchWork> TorchComm::all_gather_single(
     at::Tensor& output,
     const at::Tensor& input,

comms/torchcomms/TorchComm.hpp

Lines changed: 5 additions & 0 deletions
@@ -65,6 +65,11 @@ class TorchComm {
       const at::Tensor& tensor,
       bool async_op,
       const AllGatherOptions& options = {});
+  std::shared_ptr<TorchWork> all_gather_v(
+      const std::vector<at::Tensor>& tensor_list,
+      const at::Tensor& tensor,
+      bool async_op,
+      const AllGatherOptions& options = {});
   std::shared_ptr<TorchWork> all_gather_single(
       at::Tensor& output,
       const at::Tensor& input,

comms/torchcomms/TorchCommBackend.hpp

Lines changed: 5 additions & 0 deletions
@@ -78,6 +78,11 @@ class TorchCommBackend {
       const at::Tensor& tensor,
       bool async_op,
       const AllGatherOptions& options = {}) = 0;
+  virtual std::shared_ptr<TorchWork> all_gather_v(
+      const std::vector<at::Tensor>& tensor_list,
+      const at::Tensor& tensor,
+      bool async_op,
+      const AllGatherOptions& options = {}) = 0;
   virtual std::shared_ptr<TorchWork> all_gather_single(
       at::Tensor& output,
       const at::Tensor& input,

comms/torchcomms/TorchCommPy.cpp

Lines changed: 35 additions & 0 deletions
@@ -640,6 +640,41 @@ Output will be available on all ranks.
     tensor: the input tensor to share
     async_op: whether to perform the operation asynchronously
     hints: dictionary of string hints for backend-specific options
+    timeout: timeout for the operation
+)",
+          py::arg("tensor_list"),
+          py::arg("tensor"),
+          py::arg("async_op"),
+          py::arg("hints") = std::nullopt,
+          py::arg("timeout") = std::nullopt,
+          py::call_guard<py::gil_scoped_release>())
+      .def(
+          "all_gather_v",
+          [](TorchComm& self,
+             const std::vector<at::Tensor>& tensor_list,
+             const at::Tensor& tensor,
+             bool async_op,
+             std::optional<std::unordered_map<std::string, std::string>> hints,
+             std::optional<std::chrono::milliseconds> timeout) {
+            AllGatherOptions opts;
+            if (hints) {
+              opts.hints = *hints;
+            }
+            if (timeout) {
+              opts.timeout = *timeout;
+            }
+            return self.all_gather_v(tensor_list, tensor, async_op, opts);
+          },
+          R"(
+Gather a tensor from all ranks in the communicator, supporting variable tensor sizes per rank.
+
+Output will be available on all ranks.
+
+Args:
+    tensor_list: the list of tensors to gather into; the list is the same on all ranks, but tensor sizes may differ between indices.
+    tensor: the input tensor to share; size may differ per rank.
+    async_op: whether to perform the operation asynchronously
+    hints: dictionary of string hints for backend-specific options
     timeout: timeout for the operation
 )",
           py::arg("tensor_list"),

comms/torchcomms/_comms.pyi

Lines changed: 8 additions & 0 deletions
@@ -236,6 +236,14 @@ class TorchComm:
         hints: Dict[str, str] | None = None,
         timeout: timedelta | None = None,
     ) -> TorchWork: ...
+    def all_gather_v(
+        self,
+        tensor_list: List[Any],
+        tensor: Any,
+        async_op: bool,
+        hints: Dict[str, str] | None = None,
+        timeout: timedelta | None = None,
+    ) -> TorchWork: ...
     def all_gather_single(
         self,
         output: Any,
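The stub above also shows the optional keyword arguments. A variant of the same call passing them explicitly (the hint key is hypothetical; per the stub, timeout is a datetime.timedelta):

from datetime import timedelta

work = comm.all_gather_v(
    tensor_list,
    inp,
    async_op=False,
    hints={"example_backend_hint": "value"},  # hypothetical key; hints are backend-specific strings
    timeout=timedelta(seconds=30),
)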

comms/torchcomms/gloo/TorchCommGloo.cpp

Lines changed: 8 additions & 0 deletions
@@ -719,6 +719,14 @@ std::shared_ptr<TorchWork> TorchCommGloo::all_gather(
       async_op);
 }
 
+std::shared_ptr<TorchWork> TorchCommGloo::all_gather_v(
+    const std::vector<at::Tensor>& tensor_list,
+    const at::Tensor& tensor,
+    bool async_op,
+    const AllGatherOptions& options) {
+  throw std::runtime_error("all_gather_v is not supported in GLOO backend yet");
+}
+
 std::shared_ptr<TorchWork> TorchCommGloo::all_gather_single(
     at::Tensor& output,
     const at::Tensor& input,

comms/torchcomms/gloo/TorchCommGloo.hpp

Lines changed: 5 additions & 0 deletions
@@ -90,6 +90,11 @@ class TorchCommGloo : public TorchCommBackend,
       const at::Tensor& tensor,
       bool async_op,
       const AllGatherOptions& options = {}) override;
+  std::shared_ptr<TorchWork> all_gather_v(
+      const std::vector<at::Tensor>& tensor_list,
+      const at::Tensor& tensor,
+      bool async_op,
+      const AllGatherOptions& options = {}) override;
   std::shared_ptr<TorchWork> all_gather_single(
       at::Tensor& output,
       const at::Tensor& input,

comms/torchcomms/nccl/TorchCommNCCL.cpp

Lines changed: 8 additions & 0 deletions
@@ -691,6 +691,14 @@ std::shared_ptr<TorchWork> TorchCommNCCL::all_gather(
   return work;
 }
 
+std::shared_ptr<TorchWork> TorchCommNCCL::all_gather_v(
+    const std::vector<at::Tensor>& tensor_list,
+    const at::Tensor& tensor,
+    bool async_op,
+    const AllGatherOptions& options) {
+  throw std::runtime_error("all_gather_v is not supported in NCCL backend");
+}
+
 std::shared_ptr<TorchWork> TorchCommNCCL::all_gather_single(
     at::Tensor& output,
     const at::Tensor& input,

comms/torchcomms/nccl/TorchCommNCCL.hpp

Lines changed: 5 additions & 0 deletions
@@ -107,6 +107,11 @@ class TorchCommNCCL : public TorchCommBackend,
       const at::Tensor& tensor,
       bool async_op,
       const AllGatherOptions& options = {}) override;
+  std::shared_ptr<TorchWork> all_gather_v(
+      const std::vector<at::Tensor>& tensor_list,
+      const at::Tensor& tensor,
+      bool async_op,
+      const AllGatherOptions& options = {}) override;
   std::shared_ptr<TorchWork> all_gather_single(
       at::Tensor& output,
       const at::Tensor& input,

comms/torchcomms/ncclx/TorchCommNCCLX.cpp

Lines changed: 59 additions & 0 deletions
@@ -710,6 +710,65 @@ std::shared_ptr<TorchWork> TorchCommNCCLX::all_gather(
   return work;
 }
 
+std::shared_ptr<TorchWork> TorchCommNCCLX::all_gather_v(
+    const std::vector<at::Tensor>& tensor_list,
+    const at::Tensor& tensor,
+    bool async_op,
+    const AllGatherOptions& options) {
+  checkInitialized();
+  checkAndAbortIfTimedOutOrError();
+  if (tensor_list.size() != static_cast<size_t>(comm_size_)) {
+    throw std::runtime_error(
+        "tensor_list size must equal comm_size for all_gather");
+  }
+
+  // Ensure input tensor is contiguous
+  ensureTensorContiguous(tensor);
+
+  for (const auto& t : tensor_list) {
+    ensureTensorContiguous(t);
+  }
+  TorchCommTracingGuard tracingGuard(
+      name_, comm_size_, "all_gather_v", rank_, tensor_list, {tensor});
+
+  cudaStream_t stream = getOperationStream(async_op);
+  auto work = createWork(
+      stream, getOperationTimeout(options.timeout, options_.timeout), {tensor});
+
+  work->recordStart();
+
+  // Use multiple broadcast operations for all_gather_v
+  nccl_api_->groupStart();
+
+  for (int i = 0; i < comm_size_; ++i) {
+    // assign input/output tensors to support vector all_gather (all_gather_v)
+    // where unevenly sized inputs are gathered among participating ranks
+    auto& output = tensor_list[i];
+    auto& input = (i == rank_) ? tensor : output;
+    if (input.numel() != output.numel()) {
+      throw std::runtime_error(
+          "Output tensor size must equal input tensor size for all_gather_v");
+    }
+    nccl_api_->broadcast(
+        input.data_ptr(),
+        output.data_ptr(),
+        input.numel(),
+        getNcclDataType(output),
+        i,
+        nccl_comm_,
+        stream);
+  }
+
+  nccl_api_->groupEnd();
+
+  work->recordEnd();
+
+  // Enqueue the work after events have been recorded
+  enqueueWork(work, stream);
+
+  return work;
+}
+
 std::shared_ptr<TorchWork> TorchCommNCCLX::all_gather_single(
     at::Tensor& output,
     const at::Tensor& input,
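The NCCLX implementation above realizes all_gather_v as one broadcast per rank inside a group call: rank i's input is broadcast into tensor_list[i] on every rank, which is what lets each rank's contribution have a different size. For intuition only, here is the same pattern sketched with torch.distributed (not torchcomms code; it assumes an already-initialized process group and correctly sized tensors):

import torch
import torch.distributed as dist

def all_gather_v_sketch(tensor_list, tensor):
    """Broadcast-per-rank sketch: after the call, tensor_list[i] holds rank i's
    `tensor` on every rank; tensor_list[i] must already match rank i's size."""
    rank = dist.get_rank()
    # Seed our own slot, then let each rank broadcast its slot to all others.
    tensor_list[rank].copy_(tensor)
    for i in range(dist.get_world_size()):
        dist.broadcast(tensor_list[i], src=i)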
