@@ -39,15 +39,15 @@ std::string_view TorchComm::getCommName() const {
 }
 
 // Point-to-Point Operations
-std::shared_ptr<TorchWork> TorchComm::send(
+c10::intrusive_ptr<TorchWork> TorchComm::send(
     const at::Tensor& tensor,
     int dst,
     bool async_op,
     const SendOptions& options) {
   return impl_->send(tensor, dst, async_op, options);
 }
 
-std::shared_ptr<TorchWork> TorchComm::recv(
+c10::intrusive_ptr<TorchWork> TorchComm::recv(
     at::Tensor& tensor,
     int src,
     bool async_op,
@@ -56,23 +56,23 @@ std::shared_ptr<TorchWork> TorchComm::recv(
 }
 
 // Collective Operations
-std::shared_ptr<TorchWork> TorchComm::broadcast(
+c10::intrusive_ptr<TorchWork> TorchComm::broadcast(
     at::Tensor& tensor,
     int root,
     bool async_op,
     const BroadcastOptions& options) {
   return impl_->broadcast(tensor, root, async_op, options);
 }
 
-std::shared_ptr<TorchWork> TorchComm::all_reduce(
+c10::intrusive_ptr<TorchWork> TorchComm::all_reduce(
     at::Tensor& tensor,
     ReduceOp op,
     bool async_op,
     const AllReduceOptions& options) {
   return impl_->all_reduce(tensor, op, async_op, options);
 }
 
-std::shared_ptr<TorchWork> TorchComm::reduce(
+c10::intrusive_ptr<TorchWork> TorchComm::reduce(
     const at::Tensor& tensor,
     int root,
     ReduceOp op,
@@ -81,31 +81,31 @@ std::shared_ptr<TorchWork> TorchComm::reduce(
   return impl_->reduce(tensor, root, op, async_op, options);
 }
 
-std::shared_ptr<TorchWork> TorchComm::all_gather(
+c10::intrusive_ptr<TorchWork> TorchComm::all_gather(
     const std::vector<at::Tensor>& tensor_list,
     const at::Tensor& tensor,
     bool async_op,
     const AllGatherOptions& options) {
   return impl_->all_gather(tensor_list, tensor, async_op, options);
 }
 
-std::shared_ptr<TorchWork> TorchComm::all_gather_v(
+c10::intrusive_ptr<TorchWork> TorchComm::all_gather_v(
     const std::vector<at::Tensor>& tensor_list,
     const at::Tensor& tensor,
     bool async_op,
     const AllGatherOptions& options) {
   return impl_->all_gather_v(tensor_list, tensor, async_op, options);
 }
 
-std::shared_ptr<TorchWork> TorchComm::all_gather_single(
+c10::intrusive_ptr<TorchWork> TorchComm::all_gather_single(
     at::Tensor& output,
     const at::Tensor& input,
     bool async_op,
     const AllGatherSingleOptions& options) {
   return impl_->all_gather_single(output, input, async_op, options);
 }
 
-std::shared_ptr<TorchWork> TorchComm::reduce_scatter(
+c10::intrusive_ptr<TorchWork> TorchComm::reduce_scatter(
     at::Tensor& output,
     const std::vector<at::Tensor>& input_list,
     ReduceOp op,
@@ -114,7 +114,7 @@ std::shared_ptr<TorchWork> TorchComm::reduce_scatter(
   return impl_->reduce_scatter(output, input_list, op, async_op, options);
 }
 
-std::shared_ptr<TorchWork> TorchComm::reduce_scatter_v(
+c10::intrusive_ptr<TorchWork> TorchComm::reduce_scatter_v(
     at::Tensor& output,
     const std::vector<at::Tensor>& input_list,
     ReduceOp op,
@@ -123,7 +123,7 @@ std::shared_ptr<TorchWork> TorchComm::reduce_scatter_v(
   return impl_->reduce_scatter_v(output, input_list, op, async_op, options);
 }
 
-std::shared_ptr<TorchWork> TorchComm::reduce_scatter_single(
+c10::intrusive_ptr<TorchWork> TorchComm::reduce_scatter_single(
     at::Tensor& output,
     const at::Tensor& input,
     ReduceOp op,
@@ -132,15 +132,15 @@ std::shared_ptr<TorchWork> TorchComm::reduce_scatter_single(
   return impl_->reduce_scatter_single(output, input, op, async_op, options);
 }
 
-std::shared_ptr<TorchWork> TorchComm::all_to_all_single(
+c10::intrusive_ptr<TorchWork> TorchComm::all_to_all_single(
     at::Tensor& output,
     const at::Tensor& input,
     bool async_op,
     const AllToAllSingleOptions& options) {
   return impl_->all_to_all_single(output, input, async_op, options);
 }
 
-std::shared_ptr<TorchWork> TorchComm::all_to_all_v_single(
+c10::intrusive_ptr<TorchWork> TorchComm::all_to_all_v_single(
     at::Tensor& output,
     const at::Tensor& input,
     const std::vector<uint64_t>& output_split_sizes,
@@ -151,7 +151,7 @@ std::shared_ptr<TorchWork> TorchComm::all_to_all_v_single(
       output, input, output_split_sizes, input_split_sizes, async_op, options);
 }
 
-std::shared_ptr<TorchWork> TorchComm::all_to_all(
+c10::intrusive_ptr<TorchWork> TorchComm::all_to_all(
     const std::vector<at::Tensor>& output_tensor_list,
     const std::vector<at::Tensor>& input_tensor_list,
     bool async_op,
@@ -160,14 +160,14 @@ std::shared_ptr<TorchWork> TorchComm::all_to_all(
       output_tensor_list, input_tensor_list, async_op, options);
 }
 
-std::shared_ptr<TorchWork> TorchComm::barrier(
+c10::intrusive_ptr<TorchWork> TorchComm::barrier(
     bool async_op,
     const BarrierOptions& options) {
   return impl_->barrier(async_op, options);
 }
 
 // Scatter and Gather Operations
-std::shared_ptr<TorchWork> TorchComm::scatter(
+c10::intrusive_ptr<TorchWork> TorchComm::scatter(
     at::Tensor& output_tensor,
     const std::vector<at::Tensor>& input_tensor_list,
     int root,
@@ -177,7 +177,7 @@ std::shared_ptr<TorchWork> TorchComm::scatter(
       output_tensor, input_tensor_list, root, async_op, options);
 }
 
-std::shared_ptr<TorchWork> TorchComm::gather(
+c10::intrusive_ptr<TorchWork> TorchComm::gather(
     const std::vector<at::Tensor>& output_tensor_list,
     const at::Tensor& input_tensor,
     int root,
@@ -239,7 +239,7 @@ void BatchSendRecv::recv(at::Tensor& tensor, int src) {
   ops.push_back(op);
 }
 
-std::shared_ptr<TorchWork> BatchSendRecv::issue(
+c10::intrusive_ptr<TorchWork> BatchSendRecv::issue(
     bool async_op,
     const BatchP2POptions& options) {
   return parent_->getBackendImpl()->batch_op_issue(ops, async_op, options);
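
Note (illustrative only, not part of the diff): a minimal sketch of how a call site consumes the new return type. It assumes a hypothetical header path, a ReduceOp::SUM enumerator, a default-constructible AllReduceOptions, and a blocking TorchWork::wait() method; none of these are established by the hunks above.

// Hypothetical call site, shown only to illustrate the new
// c10::intrusive_ptr<TorchWork> return type.
#include "comms/torchcomms/TorchComm.h"  // assumed header path

void allReduceAndWait(TorchComm& comm, at::Tensor& tensor) {
  AllReduceOptions options;  // assumes the options struct is default-constructible
  // all_reduce now hands back c10::intrusive_ptr<TorchWork> instead of
  // std::shared_ptr<TorchWork>; call sites that bind the result with `auto`
  // need no changes.
  c10::intrusive_ptr<TorchWork> work =
      comm.all_reduce(tensor, ReduceOp::SUM, /*async_op=*/true, options);
  work->wait();  // assumed blocking completion API, as on other work handles
}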