
Commit e77e6d5

caandewiel authored and Frederik Mellbye committed

Add PopDistAllGather to host_collective_ops

Reviewers: #tensorflow, #framework_ip_review_-_any_oss_or_third-party_code_use_has_been_approved, jakeh
Reviewed By: #tensorflow, #framework_ip_review_-_any_oss_or_third-party_code_use_has_been_approved, jakeh
Subscribers: jakeh
Maniphest Tasks: T64481
Differential Revision: https://phabricator.sourcevertex.net/D74652

1 parent bda9700 · commit e77e6d5

File tree: 9 files changed, +254 −29 lines changed

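For orientation before the per-file diffs: the user-facing entry point added by this commit is host_collective_ops.all_gather. Below is a minimal usage sketch mirroring the new test_all_gather test; it assumes the script runs under a multi-instance PopDist environment (e.g. launched with Graphcore's poprun), which is not shown here.

import popdist

from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.ipu.distributed import host_collective_ops

popdist.init()

# Each instance contributes its own index...
x = constant_op.constant(popdist.getInstanceIndex(), dtype=dtypes.float32)

# ...and receives every instance's value, stacked along a new leading axis.
gathered = host_collective_ops.all_gather(x)
assert gathered.shape == (popdist.getNumInstances(),)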

tensorflow/compiler/plugin/poplar/kernels/popdist/BUILD

Lines changed: 18 additions & 0 deletions
@@ -5,6 +5,23 @@ package(default_visibility = [
     "//tensorflow/python/ipu:__subpackages__",
 ])
 
+poplar_cc_library(
+    name = "all_gather",
+    srcs = [
+        "all_gather.cc",
+    ],
+    deps = [
+        "//tensorflow/compiler/plugin/poplar/driver/tools:poplar_util",
+        "//tensorflow/core:framework",
+        "//tensorflow/core:lib",
+        "//tensorflow/core:lib_internal",
+        "//tensorflow/core/kernels:inplace_ops",
+        "//third_party/eigen3",
+        "@local_config_poplar//poplar:poplar_libs",
+    ],
+    alwayslink = True,
+)
+
 poplar_cc_library(
     name = "all_reduce",
     srcs = [
@@ -42,6 +59,7 @@ poplar_cc_library(
 poplar_cc_library(
     name = "popdist",
     deps = [
+        ":all_gather",
         ":all_reduce",
         ":broadcast",
     ],
tensorflow/compiler/plugin/poplar/kernels/popdist/all_gather.cc

Lines changed: 81 additions & 0 deletions
@@ -0,0 +1,81 @@
1+
/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License.
14+
==============================================================================*/
15+
#include "tensorflow/compiler/plugin/poplar/driver/tools/poplar_util.h"
16+
#include "tensorflow/core/framework/op_kernel.h"
17+
#include "tensorflow/core/framework/op_requires.h"
18+
#include "tensorflow/core/framework/register_types.h"
19+
#include "tensorflow/core/kernels/inplace_ops_functor.h"
20+
21+
#include <popdist/backend.hpp>
22+
#include <popdist/collectives.hpp>
23+
#include <popdist/context.hpp>
24+
25+
namespace poplar {
26+
template <>
27+
struct equivalent_device_type<Eigen::half> {
28+
const Type& value = HALF;
29+
};
30+
} // namespace poplar
31+
32+
namespace tensorflow {
33+
template <typename T>
34+
class PopDistAllGatherOp : public AsyncOpKernel {
35+
public:
36+
explicit PopDistAllGatherOp(OpKernelConstruction* ctx) : AsyncOpKernel(ctx) {
37+
OP_REQUIRES_OK(ctx, ctx->GetAttr("tensor_name", &tensor_name_));
38+
}
39+
40+
void ComputeAsync(OpKernelContext* ctx, DoneCallback done) override {
41+
auto& input = ctx->input(0);
42+
Tensor* output;
43+
44+
auto output_shape = input.shape();
45+
output_shape.InsertDim(0, popdist::getNumInstances());
46+
47+
OP_REQUIRES_OK_ASYNC(ctx, ctx->allocate_output(0, output_shape, &output),
48+
done);
49+
50+
Env::Default()->SchedClosure([input = &input, output, ctx, done, this] {
51+
OP_REQUIRES_OK_ASYNC(
52+
ctx,
53+
xla::poplarplugin::RunPoplarFunction<popdist::popdist_error>(
54+
[&input, &output, &ctx, &done, this] {
55+
popdist::collectives::parallel::allGather(
56+
input->flat<T>().data(), output->flat<T>().data(),
57+
input->NumElements(),
58+
poplar::equivalent_device_type<T>().value,
59+
this->tensor_name_);
60+
61+
done();
62+
}),
63+
done);
64+
});
65+
}
66+
67+
private:
68+
TF_DISALLOW_COPY_AND_ASSIGN(PopDistAllGatherOp);
69+
std::string tensor_name_;
70+
}; // namespace tensorflow
71+
72+
#define REGISTER_CPU(T) \
73+
REGISTER_KERNEL_BUILDER( \
74+
Name("PopdistAllGather").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
75+
PopDistAllGatherOp<T>);
76+
77+
TF_CALL_INTEGRAL_TYPES(REGISTER_CPU);
78+
TF_CALL_half(REGISTER_CPU);
79+
TF_CALL_float(REGISTER_CPU);
80+
#undef REGISTER_CPU
81+
} // namespace tensorflow
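The kernel above allocates the output by prepending an instance dimension (output_shape.InsertDim(0, popdist::getNumInstances())) and runs popdist::collectives::parallel::allGather off the main thread via SchedClosure. A rough numpy emulation of the resulting shape semantics (an illustration only, not the PopDist implementation):

import numpy as np

def emulated_all_gather(per_instance_tensors):
  # Each instance contributes one tensor of identical shape and dtype; the
  # result stacks them along a new leading "instance" dimension, matching
  # output_shape.InsertDim(0, getNumInstances()) in the kernel.
  return np.stack(per_instance_tensors, axis=0)

# Two instances, each contributing a shape-(2,) tensor -> shape (2, 2).
print(emulated_all_gather([np.array([0, 0]), np.array([1, 1])]))
# [[0 0]
#  [1 1]]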

tensorflow/compiler/plugin/poplar/kernels/popdist/all_reduce.cc

Lines changed: 17 additions & 15 deletions
@@ -12,8 +12,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
-#include <future>
-
 #include "tensorflow/compiler/plugin/poplar/driver/tools/poplar_util.h"
 #include "tensorflow/core/framework/op_kernel.h"
 #include "tensorflow/core/framework/op_requires.h"
@@ -53,24 +51,28 @@ class PopDistAllReduceOp : public AsyncOpKernel {
 
     auto* flattened_buffer = output->flat<T>().data();
 
-    auto future = std::async(std::launch::async, [&] {
+    Env::Default()->SchedClosure([flattened_buffer, ctx, done, this] {
       OP_REQUIRES_OK_ASYNC(
           ctx,
-          xla::poplarplugin::RunPoplarFunction<popdist::popdist_error>([&] {
-            popdist::collectives::parallel::allReduceSum(
-                flattened_buffer, input.NumElements(),
-                poplar::equivalent_device_type<T>().value, this->tensor_name_);
+          xla::poplarplugin::RunPoplarFunction<popdist::popdist_error>(
+              [&flattened_buffer, &ctx, &done, this] {
+                const auto num_elements = ctx->input(0).NumElements();
+
+                popdist::collectives::parallel::allReduceSum(
+                    flattened_buffer, num_elements,
+                    poplar::equivalent_device_type<T>().value,
+                    this->tensor_name_);
 
-            const auto num_instances = popdist::getNumInstances();
+                const auto num_instances = popdist::getNumInstances();
 
-            if (this->reduce_op_ == "MEAN") {
-              for (auto i = 0; i < input.NumElements(); ++i) {
-                *(flattened_buffer + i) /= static_cast<T>(num_instances);
-              }
-            }
+                if (this->reduce_op_ == "MEAN") {
+                  for (auto i = 0; i < num_elements; ++i) {
+                    *(flattened_buffer + i) /= static_cast<T>(num_instances);
+                  }
+                }
 
-            done();
-          }),
+                done();
+              }),
           done);
     });
   }
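The restructured all-reduce keeps its MEAN handling: the kernel implements MEAN on top of allReduceSum, summing across instances and then dividing each element by getNumInstances(). In numpy terms, roughly:

import numpy as np

def emulated_all_reduce_mean(per_instance_tensors):
  # allReduceSum across instances...
  total = np.sum(np.stack(per_instance_tensors, axis=0), axis=0)
  # ...followed by the element-wise divide the kernel performs for MEAN.
  return total / len(per_instance_tensors)

print(emulated_all_reduce_mean([np.array([0.0, 2.0]), np.array([2.0, 4.0])]))
# [1. 3.]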

tensorflow/compiler/plugin/poplar/kernels/popdist/broadcast.cc

Lines changed: 9 additions & 9 deletions
@@ -12,8 +12,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
-#include <future>
-
 #include "tensorflow/compiler/plugin/poplar/driver/tools/poplar_util.h"
 #include "tensorflow/core/framework/op_kernel.h"
 #include "tensorflow/core/framework/op_requires.h"
@@ -50,16 +48,18 @@ class PopDistBroadcastOp : public AsyncOpKernel {
         tensorflow::functor::DoCopy(ctx->eigen_cpu_device(), input, output),
         done);
 
-    auto future = std::async(std::launch::async, [&] {
+    Env::Default()->SchedClosure([output, ctx, done, this] {
       OP_REQUIRES_OK_ASYNC(
          ctx,
-          xla::poplarplugin::RunPoplarFunction<popdist::popdist_error>([&] {
-            popdist::collectives::parallel::broadcast(
-                output->flat<T>().data(), output->NumElements(),
-                poplar::equivalent_device_type<T>().value, this->tensor_name_);
+          xla::poplarplugin::RunPoplarFunction<popdist::popdist_error>(
+              [&output, &ctx, &done, this] {
+                popdist::collectives::parallel::broadcast(
+                    output->flat<T>().data(), output->NumElements(),
+                    poplar::equivalent_device_type<T>().value,
+                    this->tensor_name_);
 
-            done();
-          }),
+                done();
+              }),
           done);
     });
   }
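The broadcast kernel gets the same SchedClosure treatment. For reference, its semantics as exercised by test_broadcast in the test diff below: the root instance's value (root_rank 0 by default) overwrites the tensor on every other instance. A minimal sketch, reusing the imports from the first example above:

# Hypothetical two-instance run: instance 0 holds 42, the other holds 0.
x = constant_op.constant(42 if popdist.getInstanceIndex() == 0 else 0,
                         dtype=dtypes.int32)

# After the broadcast, every instance observes the root's value.
y = host_collective_ops.broadcast(x)  # == 42 on all instances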

tensorflow/compiler/plugin/poplar/ops/BUILD

Lines changed: 1 addition & 2 deletions
@@ -312,8 +312,7 @@ poplar_cc_library(
         "//tensorflow/compiler/plugin/poplar/ops/datastream:host_embedding",
         "//tensorflow/compiler/plugin/poplar/ops/functional",
         "//tensorflow/compiler/plugin/poplar/ops/functional:pipelining",
-        "//tensorflow/compiler/plugin/poplar/ops/popdist:all_reduce",
-        "//tensorflow/compiler/plugin/poplar/ops/popdist:broadcast",
+        "//tensorflow/compiler/plugin/poplar/ops/popdist:ops",
         "//tensorflow/compiler/plugin/poplar/ops/popfloat:cast_to_gfloat",
         "//tensorflow/compiler/plugin/poplar/ops/popnn:ctc_loss",
         "//tensorflow/compiler/plugin/poplar/ops/popnn:gelu",

tensorflow/compiler/plugin/poplar/ops/popdist/BUILD

Lines changed: 12 additions & 0 deletions
@@ -4,6 +4,17 @@ licenses(["restricted"])
 
 package(default_visibility = ["//tensorflow/compiler/plugin/poplar:__subpackages__"])
 
+poplar_cc_library(
+    name = "all_gather",
+    srcs = [
+        "all_gather.cc",
+    ],
+    deps = [
+        "//tensorflow/core:framework",
+    ],
+    alwayslink = True,
+)
+
 poplar_cc_library(
     name = "all_reduce",
     srcs = [
@@ -30,6 +41,7 @@ poplar_cc_library(
     name = "ops",
     srcs = [],
     deps = [
+        ":all_gather",
         ":all_reduce",
         ":broadcast",
     ],
tensorflow/compiler/plugin/poplar/ops/popdist/all_gather.cc

Lines changed: 25 additions & 0 deletions
@@ -0,0 +1,25 @@
1+
/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License.
14+
==============================================================================*/
15+
#include "tensorflow/core/framework/common_shape_fns.h"
16+
#include "tensorflow/core/framework/op.h"
17+
18+
namespace tensorflow {
19+
REGISTER_OP("PopdistAllGather")
20+
.Attr("T: numbertype")
21+
.Input("tensor: T")
22+
.Attr("tensor_name: string")
23+
.Output("sum: T")
24+
.SetShapeFn(shape_inference::UnknownShape);
25+
} // namespace tensorflow
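Since the op is registered with shape_inference::UnknownShape, graph-mode consumers see no static shape for the gathered result; the true shape ([num_instances] + input shape) only materializes at runtime. A small sketch of what that means in practice (illustrative, not part of the commit):

from tensorflow.python.eager import def_function

@def_function.function
def gather_in_graph_mode(x):
  gathered = host_collective_ops.all_gather(x)
  # With UnknownShape as the shape function, the static shape is unknown
  # while tracing; it is only concrete once the op actually runs.
  print(gathered.shape)  # <unknown>
  return gathered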

tensorflow/python/ipu/distributed/host_collective_ops.py

Lines changed: 12 additions & 3 deletions
@@ -23,13 +23,13 @@ def _normalize_name(name):
   return re.sub('[^a-zA-Z0-9_]', '_', name)
 
 
-def broadcast(value, root_rank=0, tensor_name=None):
+def all_gather(value, tensor_name=None):
   if not tensor_name and not context.executing_eagerly():
-    tensor_name = "PopDistBroadcast_{}".format(_normalize_name(value.name))
+    tensor_name = "PopDistAllGather_{}".format(_normalize_name(value.name))
   else:
     tensor_name = "Default"
 
-  return gen_popdist_ops.popdist_broadcast(value, tensor_name=tensor_name)
+  return gen_popdist_ops.popdist_all_gather(value, tensor_name=tensor_name)
 
 
 def all_reduce(value, reduce_op, tensor_name=None):
@@ -41,3 +41,12 @@ def all_reduce(value, reduce_op, tensor_name=None):
   return gen_popdist_ops.popdist_all_reduce(value,
                                             reduce_op=reduce_op.value,
                                             tensor_name=tensor_name)
+
+
+def broadcast(value, root_rank=0, tensor_name=None):
+  if not tensor_name and not context.executing_eagerly():
+    tensor_name = "PopDistBroadcast_{}".format(_normalize_name(value.name))
+  else:
+    tensor_name = "Default"
+
+  return gen_popdist_ops.popdist_broadcast(value, tensor_name=tensor_name)
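The tensor_name attribute is what lets instances pair up collectives even when op ordering differs between them (see test_all_gather_different_order below). In graph mode, a missing name is derived from the value's graph name via _normalize_name; in eager mode it falls back to "Default". A small illustration of the normalization, using the regex from the module above:

import re

def _normalize_name(name):
  # From host_collective_ops.py: squash anything outside [a-zA-Z0-9_].
  return re.sub('[^a-zA-Z0-9_]', '_', name)

print(_normalize_name("my_scope/weights:0"))  # my_scope_weights_0
# The derived collective name would then be
# "PopDistAllGather_my_scope_weights_0".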

tensorflow/python/ipu/distributed/host_collective_ops_test.py

Lines changed: 79 additions & 0 deletions
@@ -39,6 +39,14 @@ class HostCollectiveOpsTest(test_util.TensorFlowTestCase,
   def setUpClass(cls):
     popdist.init()
 
+  @parameterized.named_parameters(*TESTCASES)
+  def test_all_gather(self, dtype):
+    x = constant_op.constant(popdist.getInstanceIndex(), dtype=dtype)
+    self.assertAllEqual(
+        host_collective_ops.all_gather(x),
+        np.array([i for i in range(popdist.getNumInstances())],
+                 dtype=dtype.as_numpy_dtype))
+
   @parameterized.named_parameters(*TESTCASES)
   def test_all_reduce_sum(self, dtype):
     x = constant_op.constant(popdist.getInstanceIndex(), dtype=dtype)
@@ -60,6 +68,44 @@ def test_broadcast(self, dtype):
         dtype=dtype)
     self.assertAllEqual(host_collective_ops.broadcast(x), 42)
 
+  def test_all_gather_different_order(self):
+    # Call collective on `x` first and `y` afterwards.
+    @def_function.function()
+    def body_instance_even(x, y):
+      res_x = host_collective_ops.all_gather(x)
+      res_y = host_collective_ops.all_gather(y)
+
+      return (res_x, res_y)
+
+    # Call collective on `y` first and `x` afterwards.
+    @def_function.function()
+    def body_instance_odd(x, y):
+      res_y = host_collective_ops.all_gather(y)
+      res_x = host_collective_ops.all_gather(x)
+
+      return (res_x, res_y)
+
+    x = constant_op.constant(popdist.getInstanceIndex(), dtype=dtypes.float32)
+    y = constant_op.constant(
+        [popdist.getInstanceIndex(),
+         popdist.getInstanceIndex()],
+        dtype=dtypes.int32)
+
+    is_even = popdist.getInstanceIndex() % 2 == 0
+
+    # Test that we can call collectives in any order as long as our tensors have names.
+    (res_x,
+     res_y) = body_instance_even(x, y) if is_even else body_instance_odd(x, y)
+
+    self.assertAllEqual(
+        res_x,
+        np.array([i for i in range(popdist.getNumInstances())],
+                 dtype=np.float32))
+    self.assertAllEqual(
+        res_y,
+        np.array([[i, i] for i in range(popdist.getNumInstances())],
+                 dtype=np.float32))
+
   def test_all_reduce_different_order(self):
     # Call collective on `x` first and `y` afterwards.
     @def_function.function()
@@ -124,6 +170,39 @@ def body_instance_odd(x, y):
     self.assertAllEqual(res_x, 42)
     self.assertAllEqual(res_y, [42, 42])
 
+  def test_all_gather_different_dtype(self):
+    dtype = dtypes.float32 if popdist.getInstanceIndex(
+    ) % 2 == 0 else dtypes.int32
+    x = constant_op.constant(popdist.getInstanceIndex(), dtype=dtype)
+
+    try:
+      host_collective_ops.all_gather(x)
+    except errors.UnknownError as e:
+      self.assertAllEqual(
+          True,
+          "Tensor layouts did not match on all instances" in e.message,
+      )
+
+      return
+
+    self.fail()
+
+  def test_all_gather_different_shape(self):
+    value = 1 if popdist.getInstanceIndex() % 2 == 0 else [1, 1]
+    x = constant_op.constant(value, dtype=dtypes.int32)
+
+    try:
+      host_collective_ops.all_gather(x)
+    except errors.UnknownError as e:
+      self.assertAllEqual(
+          True,
+          "Tensor layouts did not match on all instances" in e.message,
+      )
+
+      return
+
+    self.fail()
+
   def test_all_reduce_different_dtype(self):
     dtype = dtypes.float32 if popdist.getInstanceIndex(
     ) % 2 == 0 else dtypes.int32