Commit 15f0d4a

Add host_collective_ops to TF replacing Horovod
Summary: Add host collective ops. Refactor of PopDist ops.
Reviewers: #tensorflow, #framework_ip_review_-_any_oss_or_third-party_code_use_has_been_approved, samuelh
Reviewed By: #tensorflow, #framework_ip_review_-_any_oss_or_third-party_code_use_has_been_approved, samuelh
Subscribers: samuelh
Maniphest Tasks: T64481, T66961
Differential Revision: https://phabricator.sourcevertex.net/D73950
1 parent 86803ca commit 15f0d4a

File tree: 14 files changed, +384 / -93 lines

tensorflow/compiler/plugin/poplar/driver/popit_backend/popit_executor.cc

Lines changed: 4 additions & 4 deletions

@@ -229,7 +229,7 @@ Status PopItExecutor::Memset32(se::Stream* stream,
 }
 bool PopItExecutor::Memcpy(se::Stream* stream, void* host_dst,
                            const se::DeviceMemoryBase& src, uint64 size) {
-  return RunPoplarFunction([&] {
+  return RunPoplarFunction<poplar::poplar_error>([&] {
     return popitCopyToHost(
         static_cast<const PopItSubBuffer*>(src.opaque())->GetDevicePtr(),
         static_cast<char*>(host_dst));
@@ -238,7 +238,7 @@ bool PopItExecutor::Memcpy(se::Stream* stream, void* host_dst,
 }
 bool PopItExecutor::Memcpy(se::Stream* stream, se::DeviceMemoryBase* dst,
                            const void* host_src, uint64 size) {
-  return RunPoplarFunction([&] {
+  return RunPoplarFunction<poplar::poplar_error>([&] {
     return popitCopyFromHost(
         static_cast<const char*>(host_src),
         static_cast<PopItSubBuffer*>(dst->opaque())->GetDevicePtr());
@@ -249,7 +249,7 @@ bool PopItExecutor::MemcpyDeviceToDevice(se::Stream* stream,
                                          se::DeviceMemoryBase* dst,
                                          const se::DeviceMemoryBase& src,
                                          uint64 size) {
-  return RunPoplarFunction([&] {
+  return RunPoplarFunction<poplar::poplar_error>([&] {
     return popitCopy(
         static_cast<const PopItSubBuffer*>(src.opaque())->GetDevicePtr(),
         static_cast<PopItSubBuffer*>(dst->opaque())->GetDevicePtr())
@@ -260,7 +260,7 @@ bool PopItExecutor::HostCallback(se::Stream* stream,
                                  std::function<void()> callback) {
   // For now sync and then callback, we should aim to make this async
   // though
-  return RunPoplarFunction([&] {
+  return RunPoplarFunction<poplar::poplar_error>([&] {
     popitSync(session_.get());
     callback();
   })

tensorflow/compiler/plugin/poplar/driver/popit_backend/popit_memory.h

Lines changed: 0 additions & 28 deletions

@@ -65,34 +65,6 @@ struct PopItSubBuffer {
   PopItSubBuffer(popitMem_t* mem, int64_t size)
       : PopItSubBuffer(PopItBufferType(mem, PopItDeallocator()), 0, size) {}
 };
-
-template <class T>
-using StatusType = typename std::conditional<std::is_same<T, void>::value,
-                                             Status, StatusOr<T>>::type;
-
-template <typename F, typename... Args>
-using DeducedReturn = StatusType<typename std::result_of<F(Args...)>::type>;
-
-Status ConvertError(const std::exception& e) {
-  return PoplarExceptionToTensorflowStatus("", e);
-}
-
-// Function that runs a poplar function and converts any errors to
-// status/statusor<T>
-template <typename F, typename... Args>
-DeducedReturn<F, Args...> RunPoplarFunction(F f, Args&&... args) {
-  try {
-    if constexpr (std::is_same<DeducedReturn<F, Args...>, Status>::value) {
-      f(std::forward<Args>(args)...);
-      return Status::OK();
-    } else {
-      return f(std::forward<Args>(args)...);
-    }
-  } catch (const poplar::poplar_error& e) {
-    return ConvertError(e);
-  }
-}
-
 }  // namespace poplarplugin
 }  // namespace xla

tensorflow/compiler/plugin/poplar/driver/tools/poplar_util.cc

Lines changed: 4 additions & 0 deletions

@@ -1042,5 +1042,9 @@ void CheckPoplarPackageHash() {
 }
 }
 }
+
+Status ConvertError(const std::exception& e) {
+  return PoplarExceptionToTensorflowStatus("", e);
+}
 }  // namespace poplarplugin
 }  // namespace xla

tensorflow/compiler/plugin/poplar/driver/tools/poplar_util.h

Lines changed: 29 additions & 2 deletions

@@ -19,6 +19,10 @@ limitations under the License.
  * These functions are related to poplar, and cannot be used within the
  * optimizers target in the BUILD file.
  */
+#include <string>
+#include <utility>
+#include <vector>
+
 #include <gcl/Collectives.hpp>
 #include <poplar/Program.hpp>
 #include <poplar/exceptions.hpp>
@@ -27,8 +31,6 @@ limitations under the License.
 #include <popnn/Pooling.hpp>
 #include <popops/Expr.hpp>
 #include <poputil/exceptions.hpp>
-#include <string>
-#include <vector>

 #include "absl/container/inlined_vector.h"
 #include "absl/types/optional.h"
@@ -282,6 +284,31 @@ bool HasIOTiles(CompilerResources& res);
 int64_t GetNumIPUs(CompilerResources& res);

 void CheckPoplarPackageHash();
+
+template <class T>
+using StatusType = typename std::conditional<std::is_same<T, void>::value,
+                                             Status, StatusOr<T>>::type;
+
+template <typename F, typename... Args>
+using DeducedReturn = StatusType<typename std::result_of<F(Args...)>::type>;
+
+Status ConvertError(const std::exception& e);
+
+// Function that runs a poplar function and converts any errors to
+// status/statusor<T>
+template <typename E, typename F, typename... Args>
+DeducedReturn<F, Args...> RunPoplarFunction(F f, Args&&... args) {
+  try {
+    if constexpr (std::is_same<DeducedReturn<F, Args...>, Status>::value) {
+      f(std::forward<Args>(args)...);
+      return Status::OK();
+    } else {
+      return f(std::forward<Args>(args)...);
+    }
+  } catch (const E& e) {
+    return ConvertError(e);
+  }
+}
 }  // namespace poplarplugin
 }  // namespace xla
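Note: the helper that moved here from popit_memory.h now takes the exception type to translate as an explicit template parameter, so the same wrapper can convert poplar::poplar_error in the PopIt executor and popdist::popdist_error in the new host collective kernels. The following is a self-contained sketch of that pattern with toy Status/StatusOr stand-ins and std::runtime_error instead of the Poplar/TensorFlow types; it compiles on its own and is illustrative only, not code from this commit.

// Illustrative analogue of RunPoplarFunction<E> (not part of the commit):
// run a callable, deduce Status vs StatusOr<T> from its return type, and
// translate exceptions of type E into an error status.
#include <iostream>
#include <stdexcept>
#include <string>
#include <type_traits>
#include <utility>

struct Status {
  std::string message;  // empty means OK
  static Status OK() { return {}; }
};

template <class T>
struct StatusOr {
  Status status;
  T value{};
};

template <class T>
using StatusType = typename std::conditional<std::is_same<T, void>::value,
                                             Status, StatusOr<T>>::type;

template <typename F, typename... Args>
using DeducedReturn = StatusType<typename std::result_of<F(Args...)>::type>;

// E is the exception type to translate, mirroring the templated catch clause
// added to RunPoplarFunction in this commit.
template <typename E, typename F, typename... Args>
DeducedReturn<F, Args...> RunAndConvert(F f, Args&&... args) {
  try {
    if constexpr (std::is_same<DeducedReturn<F, Args...>, Status>::value) {
      f(std::forward<Args>(args)...);
      return Status::OK();
    } else {
      return {Status::OK(), f(std::forward<Args>(args)...)};
    }
  } catch (const E& e) {
    if constexpr (std::is_same<DeducedReturn<F, Args...>, Status>::value) {
      return Status{e.what()};  // stand-in for ConvertError(e)
    } else {
      return {Status{e.what()}};
    }
  }
}

int main() {
  // void callable -> Status, with the exception translated to an error.
  Status s = RunAndConvert<std::runtime_error>(
      [] { throw std::runtime_error("copy failed"); });
  std::cout << (s.message.empty() ? "OK" : s.message) << "\n";  // copy failed

  // value-returning callable -> StatusOr<int>.
  StatusOr<int> v = RunAndConvert<std::runtime_error>([] { return 42; });
  std::cout << v.value << "\n";  // 42
}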

tensorflow/compiler/plugin/poplar/kernels/popdist/BUILD

Lines changed: 2 additions & 0 deletions

@@ -11,6 +11,7 @@ poplar_cc_library(
         "all_reduce.cc",
     ],
     deps = [
+        "//tensorflow/compiler/plugin/poplar/driver/tools:poplar_util",
         "//tensorflow/core:framework",
         "//tensorflow/core:lib",
         "//tensorflow/core:lib_internal",
@@ -27,6 +28,7 @@ poplar_cc_library(
         "broadcast.cc",
     ],
     deps = [
+        "//tensorflow/compiler/plugin/poplar/driver/tools:poplar_util",
         "//tensorflow/core:framework",
         "//tensorflow/core:lib",
         "//tensorflow/core:lib_internal",

tensorflow/compiler/plugin/poplar/kernels/popdist/all_reduce.cc

Lines changed: 41 additions & 16 deletions

@@ -12,6 +12,9 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
+#include <future>
+
+#include "tensorflow/compiler/plugin/poplar/driver/tools/poplar_util.h"
 #include "tensorflow/core/framework/op_kernel.h"
 #include "tensorflow/core/framework/op_requires.h"
 #include "tensorflow/core/framework/register_types.h"
@@ -21,49 +24,71 @@ limitations under the License.
 #include <popdist/collectives.hpp>
 #include <popdist/context.hpp>

+namespace poplar {
+template <>
+struct equivalent_device_type<Eigen::half> {
+  const Type& value = HALF;
+};
+}  // namespace poplar
+
 namespace tensorflow {
 template <typename T>
-class PopDistAllReduceOp : public OpKernel {
+class PopDistAllReduceOp : public AsyncOpKernel {
  public:
-  explicit PopDistAllReduceOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
+  explicit PopDistAllReduceOp(OpKernelConstruction* ctx) : AsyncOpKernel(ctx) {
     OP_REQUIRES_OK(ctx, ctx->GetAttr("reduce_op", &reduce_op_));
+    OP_REQUIRES_OK(ctx, ctx->GetAttr("tensor_name", &tensor_name_));
   }

-  void Compute(OpKernelContext* ctx) override {
+  void ComputeAsync(OpKernelContext* ctx, DoneCallback done) override {
     auto& input = ctx->input(0);
     Tensor* output;

-    OP_REQUIRES_OK(ctx, ctx->allocate_output(0, input.shape(), &output));
-    OP_REQUIRES_OK(ctx, tensorflow::functor::DoCopy(ctx->eigen_cpu_device(),
-                                                    input, output));
+    OP_REQUIRES_OK_ASYNC(ctx, ctx->allocate_output(0, input.shape(), &output),
+                         done);
+    OP_REQUIRES_OK_ASYNC(
+        ctx,
+        tensorflow::functor::DoCopy(ctx->eigen_cpu_device(), input, output),
+        done);

     auto* flattened_buffer = output->flat<T>().data();

-    popdist::collectives::sequential::allReduceSum(
-        flattened_buffer, input.NumElements(),
-        poplar::equivalent_device_type<T>().value);
+    auto future = std::async(std::launch::async, [&] {
+      OP_REQUIRES_OK_ASYNC(
+          ctx,
+          xla::poplarplugin::RunPoplarFunction<popdist::popdist_error>([&] {
+            popdist::collectives::parallel::allReduceSum(
+                flattened_buffer, input.NumElements(),
+                poplar::equivalent_device_type<T>().value, this->tensor_name_);

-    const auto num_instances = popdist::getNumInstances();
+            const auto num_instances = popdist::getNumInstances();

-    if (reduce_op_ == "MEAN") {
-      for (auto i = 0; i < input.NumElements(); ++i) {
-        *(flattened_buffer + i) /= num_instances;
-      }
-    }
+            if (this->reduce_op_ == "MEAN") {
+              for (auto i = 0; i < input.NumElements(); ++i) {
+                *(flattened_buffer + i) /= static_cast<T>(num_instances);
+              }
+            }
+
+            done();
+          }),
+          done);
+    });
   }

  private:
  TF_DISALLOW_COPY_AND_ASSIGN(PopDistAllReduceOp);

  std::string reduce_op_;
-};
+  std::string tensor_name_;
+};  // namespace tensorflow

 #define REGISTER_CPU(T)                                                   \
  REGISTER_KERNEL_BUILDER(                                                 \
      Name("PopdistAllReduce").Device(DEVICE_CPU).TypeConstraint<T>("T"),  \
      PopDistAllReduceOp<T>);

 TF_CALL_INTEGRAL_TYPES(REGISTER_CPU);
+TF_CALL_half(REGISTER_CPU);
 TF_CALL_float(REGISTER_CPU);
 #undef REGISTER_CPU
 }  // namespace tensorflow
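Note: the kernel now derives from AsyncOpKernel; the output is allocated and the input copied on the calling thread, the PopDist collective runs inside a std::async task, and done is invoked once the reduction (and the optional MEAN division by the instance count) has finished. Below is a self-contained sketch of that control flow with plain stand-ins for the TensorFlow and PopDist types; FakeAllReduceSum and ComputeAsyncSketch are hypothetical names, not API from this commit.

// Illustrative control-flow sketch only; TensorFlow/PopDist types are replaced
// by simple stand-ins and the collective is a no-op for a single instance.
#include <functional>
#include <future>
#include <iostream>
#include <string>
#include <vector>

using DoneCallback = std::function<void()>;

// Hypothetical stand-in for popdist::collectives::parallel::allReduceSum.
void FakeAllReduceSum(float* /*data*/, std::size_t /*n*/) {}

std::future<void> ComputeAsyncSketch(const std::vector<float>& input,
                                     std::vector<float>* output,
                                     const std::string& reduce_op,
                                     int num_instances, DoneCallback done) {
  *output = input;  // mirrors allocate_output + DoCopy onto the output tensor
  // Run the collective off the calling thread; the caller keeps the future
  // alive (a std::async future joins in its destructor).
  return std::async(std::launch::async,
                    [output, reduce_op, num_instances, done] {
                      FakeAllReduceSum(output->data(), output->size());
                      if (reduce_op == "MEAN") {
                        for (auto& v : *output) {
                          v /= static_cast<float>(num_instances);
                        }
                      }
                      done();  // report completion after the reduce finishes
                    });
}

int main() {
  std::vector<float> in{2.f, 4.f, 6.f}, out;
  auto pending = ComputeAsyncSketch(in, &out, "MEAN", /*num_instances=*/2,
                                    [] { std::cout << "done\n"; });
  pending.wait();
  for (float v : out) std::cout << v << " ";  // prints: 1 2 3
  std::cout << "\n";
}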

tensorflow/compiler/plugin/poplar/kernels/popdist/broadcast.cc

Lines changed: 37 additions & 10 deletions

@@ -12,6 +12,9 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
+#include <future>
+
+#include "tensorflow/compiler/plugin/poplar/driver/tools/poplar_util.h"
 #include "tensorflow/core/framework/op_kernel.h"
 #include "tensorflow/core/framework/op_requires.h"
 #include "tensorflow/core/framework/register_types.h"
@@ -21,34 +24,58 @@ limitations under the License.
 #include <popdist/collectives.hpp>
 #include <popdist/context.hpp>

+namespace poplar {
+template <>
+struct equivalent_device_type<Eigen::half> {
+  const Type& value = HALF;
+};
+}  // namespace poplar
+
 namespace tensorflow {
 template <typename T>
-class PopDistBroadcastOp : public OpKernel {
+class PopDistBroadcastOp : public AsyncOpKernel {
  public:
-  explicit PopDistBroadcastOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
+  explicit PopDistBroadcastOp(OpKernelConstruction* ctx) : AsyncOpKernel(ctx) {
+    OP_REQUIRES_OK(ctx, ctx->GetAttr("tensor_name", &tensor_name_));
+  }

-  void Compute(OpKernelContext* ctx) override {
+  void ComputeAsync(OpKernelContext* ctx, DoneCallback done) override {
     auto& input = ctx->input(0);
     Tensor* output;

-    OP_REQUIRES_OK(ctx, ctx->allocate_output(0, input.shape(), &output));
-    OP_REQUIRES_OK(ctx, tensorflow::functor::DoCopy(ctx->eigen_cpu_device(),
-                                                    input, output));
-    popdist::collectives::sequential::broadcast(
-        output->flat<T>().data(), output->NumElements(),
-        poplar::equivalent_device_type<T>().value);
+    OP_REQUIRES_OK_ASYNC(ctx, ctx->allocate_output(0, input.shape(), &output),
+                         done);
+    OP_REQUIRES_OK_ASYNC(
+        ctx,
+        tensorflow::functor::DoCopy(ctx->eigen_cpu_device(), input, output),
+        done);
+
+    auto future = std::async(std::launch::async, [&] {
+      OP_REQUIRES_OK_ASYNC(
+          ctx,
+          xla::poplarplugin::RunPoplarFunction<popdist::popdist_error>([&] {
+            popdist::collectives::parallel::broadcast(
+                output->flat<T>().data(), output->NumElements(),
+                poplar::equivalent_device_type<T>().value, this->tensor_name_);
+
+            done();
+          }),
+          done);
+    });
   }

  private:
  TF_DISALLOW_COPY_AND_ASSIGN(PopDistBroadcastOp);
-};
+  std::string tensor_name_;
+};  // namespace tensorflow

 #define REGISTER_CPU(T)                                                  \
  REGISTER_KERNEL_BUILDER(                                                \
      Name("PopdistBroadcast").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
      PopDistBroadcastOp<T>);

 TF_CALL_INTEGRAL_TYPES(REGISTER_CPU);
+TF_CALL_half(REGISTER_CPU);
 TF_CALL_float(REGISTER_CPU);
 #undef REGISTER_CPU
 }  // namespace tensorflow
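Note: both kernels now also register a half-precision variant (TF_CALL_half), which is why they specialize poplar::equivalent_device_type for Eigen::half at the top of the file. The snippet below is a self-contained toy version of that trait pattern, mapping a host element type to a device-type value; the enum and the Half struct are stand-ins, not the Poplar API.

// Toy illustration of the equivalent_device_type specialization pattern
// (stand-in types only; poplar::Type and Eigen::half are not used here).
#include <iostream>

enum class DeviceType { HALF, FLOAT, INT };

template <typename T>
struct equivalent_device_type;  // primary template: no mapping by default

template <>
struct equivalent_device_type<float> {
  DeviceType value = DeviceType::FLOAT;
};

struct Half {};  // stand-in for Eigen::half

// Mirrors the specialization added in the diff: give the half type a mapping
// so kernels templated on it can ask for the matching device type.
template <>
struct equivalent_device_type<Half> {
  DeviceType value = DeviceType::HALF;
};

int main() {
  std::cout << static_cast<int>(equivalent_device_type<float>().value) << "\n";  // 1
  std::cout << static_cast<int>(equivalent_device_type<Half>().value) << "\n";   // 0
}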

tensorflow/compiler/plugin/poplar/ops/popdist/all_reduce.cc

Lines changed: 2 additions & 7 deletions

@@ -20,12 +20,7 @@ REGISTER_OP("PopdistAllReduce")
     .Attr("T: numbertype")
     .Input("tensor: T")
     .Attr("reduce_op: string")
+    .Attr("tensor_name: string")
     .Output("sum: T")
-    .SetShapeFn([](shape_inference::InferenceContext* c) {
-      shape_inference::ShapeHandle output;
-      TF_RETURN_IF_ERROR(
-          c->ReplaceDim(c->input(0), 0, c->UnknownDim(), &output));
-      c->set_output(0, output);
-      return Status::OK();
-    });
+    .SetShapeFn(shape_inference::UnchangedShape);
 }  // namespace tensorflow
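Note: shape_inference::UnchangedShape forwards the input shape to the output, which matches an all-reduce (the tensor keeps its shape across instances), replacing the removed lambda that cleared the leading dimension. Written out in the style of the removed code, the equivalent shape function would look roughly like the sketch below (illustrative, not from the commit):

.SetShapeFn([](shape_inference::InferenceContext* c) {
  c->set_output(0, c->input(0));  // output shape is identical to the input shape
  return Status::OK();
});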

tensorflow/compiler/plugin/poplar/ops/popdist/broadcast.cc

Lines changed: 2 additions & 11 deletions

@@ -19,16 +19,7 @@ namespace tensorflow {
 REGISTER_OP("PopdistBroadcast")
     .Attr("T: numbertype")
     .Input("tensor: T")
+    .Attr("tensor_name: string")
     .Output("sum: T")
-    .SetShapeFn([](shape_inference::InferenceContext* c) {
-      shape_inference::ShapeHandle output;
-
-      if (c->Rank(c->input(0)) > 0) {
-        TF_RETURN_IF_ERROR(
-            c->ReplaceDim(c->input(0), 0, c->UnknownDim(), &output));
-      }
-
-      c->set_output(0, output);
-      return Status::OK();
-    });
+    .SetShapeFn(shape_inference::UnchangedShape);
 }  // namespace tensorflow
