Skip to content

Commit 7b52531

Browse files
siyengar authored and meta-codesync[bot] committed
change holder of work to intrusive ptr from shared ptr (#37)
Summary: Pull Request resolved: #37 Change holder type of Work objects to an intrusive_ptr. This will reduce the overhead of pybind return value wrapping logic. Previously it needed to hold a shared_ptr and wrap it. intrusive_ptr does not need this wrapping any more. This will bring this overhead closer to the processgroup. Reviewed By: d4l3k Differential Revision: D85934777 fbshipit-source-id: 12ab355a76da6ac09fce0553a37a35a734c96a03
1 parent f0a4112 commit 7b52531

35 files changed

+330
-322
lines changed

comms/torchcomms/BackendWrapper.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ std::vector<uint64_t> toVecUint64(const std::vector<int64_t>& vec) {
5151

5252
} // namespace
5353

54-
WorkWrapper::WorkWrapper(std::shared_ptr<TorchWork> work)
54+
WorkWrapper::WorkWrapper(c10::intrusive_ptr<TorchWork> work)
5555
: work_(std::move(work)) {}
5656

5757
bool WorkWrapper::isCompleted() {

comms/torchcomms/BackendWrapper.hpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ namespace comms {
1212

1313
class WorkWrapper : public c10d::Work {
1414
public:
15-
explicit WorkWrapper(std::shared_ptr<TorchWork> work);
15+
explicit WorkWrapper(c10::intrusive_ptr<TorchWork> work);
1616
~WorkWrapper() override = default;
1717

1818
bool isCompleted() override;
@@ -23,7 +23,7 @@ class WorkWrapper : public c10d::Work {
2323
std::vector<at::Tensor> result() override;
2424

2525
private:
26-
std::shared_ptr<TorchWork> work_;
26+
c10::intrusive_ptr<TorchWork> work_;
2727
};
2828

2929
using c10d::kUnsetTimeout;

comms/torchcomms/TorchComm.cpp

Lines changed: 18 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -39,15 +39,15 @@ std::string_view TorchComm::getCommName() const {
3939
}
4040

4141
// Point-to-Point Operations
42-
std::shared_ptr<TorchWork> TorchComm::send(
42+
c10::intrusive_ptr<TorchWork> TorchComm::send(
4343
const at::Tensor& tensor,
4444
int dst,
4545
bool async_op,
4646
const SendOptions& options) {
4747
return impl_->send(tensor, dst, async_op, options);
4848
}
4949

50-
std::shared_ptr<TorchWork> TorchComm::recv(
50+
c10::intrusive_ptr<TorchWork> TorchComm::recv(
5151
at::Tensor& tensor,
5252
int src,
5353
bool async_op,
@@ -56,23 +56,23 @@ std::shared_ptr<TorchWork> TorchComm::recv(
5656
}
5757

5858
// Collective Operations
59-
std::shared_ptr<TorchWork> TorchComm::broadcast(
59+
c10::intrusive_ptr<TorchWork> TorchComm::broadcast(
6060
at::Tensor& tensor,
6161
int root,
6262
bool async_op,
6363
const BroadcastOptions& options) {
6464
return impl_->broadcast(tensor, root, async_op, options);
6565
}
6666

67-
std::shared_ptr<TorchWork> TorchComm::all_reduce(
67+
c10::intrusive_ptr<TorchWork> TorchComm::all_reduce(
6868
at::Tensor& tensor,
6969
ReduceOp op,
7070
bool async_op,
7171
const AllReduceOptions& options) {
7272
return impl_->all_reduce(tensor, op, async_op, options);
7373
}
7474

75-
std::shared_ptr<TorchWork> TorchComm::reduce(
75+
c10::intrusive_ptr<TorchWork> TorchComm::reduce(
7676
const at::Tensor& tensor,
7777
int root,
7878
ReduceOp op,
@@ -81,31 +81,31 @@ std::shared_ptr<TorchWork> TorchComm::reduce(
8181
return impl_->reduce(tensor, root, op, async_op, options);
8282
}
8383

84-
std::shared_ptr<TorchWork> TorchComm::all_gather(
84+
c10::intrusive_ptr<TorchWork> TorchComm::all_gather(
8585
const std::vector<at::Tensor>& tensor_list,
8686
const at::Tensor& tensor,
8787
bool async_op,
8888
const AllGatherOptions& options) {
8989
return impl_->all_gather(tensor_list, tensor, async_op, options);
9090
}
9191

92-
std::shared_ptr<TorchWork> TorchComm::all_gather_v(
92+
c10::intrusive_ptr<TorchWork> TorchComm::all_gather_v(
9393
const std::vector<at::Tensor>& tensor_list,
9494
const at::Tensor& tensor,
9595
bool async_op,
9696
const AllGatherOptions& options) {
9797
return impl_->all_gather_v(tensor_list, tensor, async_op, options);
9898
}
9999

100-
std::shared_ptr<TorchWork> TorchComm::all_gather_single(
100+
c10::intrusive_ptr<TorchWork> TorchComm::all_gather_single(
101101
at::Tensor& output,
102102
const at::Tensor& input,
103103
bool async_op,
104104
const AllGatherSingleOptions& options) {
105105
return impl_->all_gather_single(output, input, async_op, options);
106106
}
107107

108-
std::shared_ptr<TorchWork> TorchComm::reduce_scatter(
108+
c10::intrusive_ptr<TorchWork> TorchComm::reduce_scatter(
109109
at::Tensor& output,
110110
const std::vector<at::Tensor>& input_list,
111111
ReduceOp op,
@@ -114,7 +114,7 @@ std::shared_ptr<TorchWork> TorchComm::reduce_scatter(
114114
return impl_->reduce_scatter(output, input_list, op, async_op, options);
115115
}
116116

117-
std::shared_ptr<TorchWork> TorchComm::reduce_scatter_v(
117+
c10::intrusive_ptr<TorchWork> TorchComm::reduce_scatter_v(
118118
at::Tensor& output,
119119
const std::vector<at::Tensor>& input_list,
120120
ReduceOp op,
@@ -123,7 +123,7 @@ std::shared_ptr<TorchWork> TorchComm::reduce_scatter_v(
123123
return impl_->reduce_scatter_v(output, input_list, op, async_op, options);
124124
}
125125

126-
std::shared_ptr<TorchWork> TorchComm::reduce_scatter_single(
126+
c10::intrusive_ptr<TorchWork> TorchComm::reduce_scatter_single(
127127
at::Tensor& output,
128128
const at::Tensor& input,
129129
ReduceOp op,
@@ -132,15 +132,15 @@ std::shared_ptr<TorchWork> TorchComm::reduce_scatter_single(
132132
return impl_->reduce_scatter_single(output, input, op, async_op, options);
133133
}
134134

135-
std::shared_ptr<TorchWork> TorchComm::all_to_all_single(
135+
c10::intrusive_ptr<TorchWork> TorchComm::all_to_all_single(
136136
at::Tensor& output,
137137
const at::Tensor& input,
138138
bool async_op,
139139
const AllToAllSingleOptions& options) {
140140
return impl_->all_to_all_single(output, input, async_op, options);
141141
}
142142

143-
std::shared_ptr<TorchWork> TorchComm::all_to_all_v_single(
143+
c10::intrusive_ptr<TorchWork> TorchComm::all_to_all_v_single(
144144
at::Tensor& output,
145145
const at::Tensor& input,
146146
const std::vector<uint64_t>& output_split_sizes,
@@ -151,7 +151,7 @@ std::shared_ptr<TorchWork> TorchComm::all_to_all_v_single(
151151
output, input, output_split_sizes, input_split_sizes, async_op, options);
152152
}
153153

154-
std::shared_ptr<TorchWork> TorchComm::all_to_all(
154+
c10::intrusive_ptr<TorchWork> TorchComm::all_to_all(
155155
const std::vector<at::Tensor>& output_tensor_list,
156156
const std::vector<at::Tensor>& input_tensor_list,
157157
bool async_op,
@@ -160,14 +160,14 @@ std::shared_ptr<TorchWork> TorchComm::all_to_all(
160160
output_tensor_list, input_tensor_list, async_op, options);
161161
}
162162

163-
std::shared_ptr<TorchWork> TorchComm::barrier(
163+
c10::intrusive_ptr<TorchWork> TorchComm::barrier(
164164
bool async_op,
165165
const BarrierOptions& options) {
166166
return impl_->barrier(async_op, options);
167167
}
168168

169169
// Scatter and Gather Operations
170-
std::shared_ptr<TorchWork> TorchComm::scatter(
170+
c10::intrusive_ptr<TorchWork> TorchComm::scatter(
171171
at::Tensor& output_tensor,
172172
const std::vector<at::Tensor>& input_tensor_list,
173173
int root,
@@ -177,7 +177,7 @@ std::shared_ptr<TorchWork> TorchComm::scatter(
177177
output_tensor, input_tensor_list, root, async_op, options);
178178
}
179179

180-
std::shared_ptr<TorchWork> TorchComm::gather(
180+
c10::intrusive_ptr<TorchWork> TorchComm::gather(
181181
const std::vector<at::Tensor>& output_tensor_list,
182182
const at::Tensor& input_tensor,
183183
int root,
@@ -239,7 +239,7 @@ void BatchSendRecv::recv(at::Tensor& tensor, int src) {
239239
ops.push_back(op);
240240
}
241241

242-
std::shared_ptr<TorchWork> BatchSendRecv::issue(
242+
c10::intrusive_ptr<TorchWork> BatchSendRecv::issue(
243243
bool async_op,
244244
const BatchP2POptions& options) {
245245
return parent_->getBackendImpl()->batch_op_issue(ops, async_op, options);

comms/torchcomms/TorchComm.hpp

Lines changed: 17 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -32,96 +32,96 @@ class TorchComm {
3232
std::string_view getCommName() const;
3333

3434
// Point-to-Point Operations
35-
std::shared_ptr<TorchWork> send(
35+
c10::intrusive_ptr<TorchWork> send(
3636
const at::Tensor& tensor,
3737
int dst,
3838
bool async_op,
3939
const SendOptions& options = {});
40-
std::shared_ptr<TorchWork> recv(
40+
c10::intrusive_ptr<TorchWork> recv(
4141
at::Tensor& tensor,
4242
int src,
4343
bool async_op,
4444
const RecvOptions& options = {});
4545

4646
// Collective Operations
47-
std::shared_ptr<TorchWork> broadcast(
47+
c10::intrusive_ptr<TorchWork> broadcast(
4848
at::Tensor& tensor,
4949
int root,
5050
bool async_op,
5151
const BroadcastOptions& options = {});
52-
std::shared_ptr<TorchWork> all_reduce(
52+
c10::intrusive_ptr<TorchWork> all_reduce(
5353
at::Tensor& tensor,
5454
ReduceOp op,
5555
bool async_op,
5656
const AllReduceOptions& options = {});
57-
std::shared_ptr<TorchWork> reduce(
57+
c10::intrusive_ptr<TorchWork> reduce(
5858
const at::Tensor& tensor,
5959
int root,
6060
ReduceOp op,
6161
bool async_op,
6262
const ReduceOptions& options = {});
63-
std::shared_ptr<TorchWork> all_gather(
63+
c10::intrusive_ptr<TorchWork> all_gather(
6464
const std::vector<at::Tensor>& tensor_list,
6565
const at::Tensor& tensor,
6666
bool async_op,
6767
const AllGatherOptions& options = {});
68-
std::shared_ptr<TorchWork> all_gather_v(
68+
c10::intrusive_ptr<TorchWork> all_gather_v(
6969
const std::vector<at::Tensor>& tensor_list,
7070
const at::Tensor& tensor,
7171
bool async_op,
7272
const AllGatherOptions& options = {});
73-
std::shared_ptr<TorchWork> all_gather_single(
73+
c10::intrusive_ptr<TorchWork> all_gather_single(
7474
at::Tensor& output,
7575
const at::Tensor& input,
7676
bool async_op,
7777
const AllGatherSingleOptions& options = {});
78-
std::shared_ptr<TorchWork> reduce_scatter(
78+
c10::intrusive_ptr<TorchWork> reduce_scatter(
7979
at::Tensor& output,
8080
const std::vector<at::Tensor>& input_list,
8181
ReduceOp op,
8282
bool async_op,
8383
const ReduceScatterOptions& options = {});
84-
std::shared_ptr<TorchWork> reduce_scatter_v(
84+
c10::intrusive_ptr<TorchWork> reduce_scatter_v(
8585
at::Tensor& output,
8686
const std::vector<at::Tensor>& input_list,
8787
ReduceOp op,
8888
bool async_op,
8989
const ReduceScatterOptions& options = {});
90-
std::shared_ptr<TorchWork> reduce_scatter_single(
90+
c10::intrusive_ptr<TorchWork> reduce_scatter_single(
9191
at::Tensor& output,
9292
const at::Tensor& input,
9393
ReduceOp op,
9494
bool async_op,
9595
const ReduceScatterSingleOptions& options = {});
96-
std::shared_ptr<TorchWork> all_to_all_single(
96+
c10::intrusive_ptr<TorchWork> all_to_all_single(
9797
at::Tensor& output,
9898
const at::Tensor& input,
9999
bool async_op,
100100
const AllToAllSingleOptions& options = {});
101-
std::shared_ptr<TorchWork> all_to_all_v_single(
101+
c10::intrusive_ptr<TorchWork> all_to_all_v_single(
102102
at::Tensor& output,
103103
const at::Tensor& input,
104104
const std::vector<uint64_t>& output_split_sizes,
105105
const std::vector<uint64_t>& input_split_sizes,
106106
bool async_op,
107107
const AllToAllvSingleOptions& options = {});
108-
std::shared_ptr<TorchWork> all_to_all(
108+
c10::intrusive_ptr<TorchWork> all_to_all(
109109
const std::vector<at::Tensor>& output_tensor_list,
110110
const std::vector<at::Tensor>& input_tensor_list,
111111
bool async_op,
112112
const AllToAllOptions& options = {});
113-
std::shared_ptr<TorchWork> barrier(
113+
c10::intrusive_ptr<TorchWork> barrier(
114114
bool async_op,
115115
const BarrierOptions& options = {});
116116

117117
// Scatter and Gather Operations
118-
std::shared_ptr<TorchWork> scatter(
118+
c10::intrusive_ptr<TorchWork> scatter(
119119
at::Tensor& output_tensor,
120120
const std::vector<at::Tensor>& input_tensor_list,
121121
int root,
122122
bool async_op,
123123
const ScatterOptions& options = {});
124-
std::shared_ptr<TorchWork> gather(
124+
c10::intrusive_ptr<TorchWork> gather(
125125
const std::vector<at::Tensor>& output_tensor_list,
126126
const at::Tensor& input_tensor,
127127
int root,

0 commit comments

Comments
 (0)