Skip to content

Commit 9ce6427

Browse files
committed
Force multislice/update to be 2D
Summary: Ref T45278 Test Plan: CI Reviewers: jakeh, georgew, jackh, #tensorflow, #framework_ip_review_-_any_oss_or_third-party_code_use_has_been_approved Reviewed By: georgew, #tensorflow, #framework_ip_review_-_any_oss_or_third-party_code_use_has_been_approved Maniphest Tasks: T45278 Differential Revision: https://phabricator.sourcevertex.net/D52785
1 parent 3ebb04a commit 9ce6427

File tree

3 files changed

+37
-16
lines changed

3 files changed

+37
-16
lines changed

tensorflow/compiler/plugin/poplar/kernels/popops/multi_slice.cc

Lines changed: 26 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,17 @@ class PopopsMultiSliceOp : public XlaOpKernel, IpuOpKernel {
3030

3131
void Compile(XlaOpKernelContext* ctx) override {
3232
const TensorShape input_shape = ctx->InputShape(0);
33-
TensorShape output_shape = ctx->InputShape(1);
33+
const TensorShape indices_shape = ctx->InputShape(1);
34+
35+
OP_REQUIRES(ctx, TensorShapeUtils::IsMatrix(input_shape),
36+
errors::InvalidArgument("input shape must be 2D, but got: ",
37+
input_shape.DebugString()));
38+
39+
OP_REQUIRES(ctx, TensorShapeUtils::IsVector(indices_shape),
40+
errors::InvalidArgument("indices shape must be 1D, but got: ",
41+
indices_shape.DebugString()));
42+
43+
TensorShape output_shape = indices_shape;
3444
output_shape.AddDim(input_shape.dim_size(1));
3545

3646
xla::PrimitiveType input_type;
@@ -63,10 +73,18 @@ class PopopsMultiUpdateOp : public XlaOpKernel, IpuOpKernel {
6373
const TensorShape input_shape = ctx->InputShape(0);
6474
const TensorShape indices_shape = ctx->InputShape(1);
6575
const TensorShape updates_shape = ctx->InputShape(2);
66-
xla::PrimitiveType input_type;
67-
OP_REQUIRES_OK(ctx,
68-
DataTypeToPrimitiveType(ctx->input_type(0), &input_type));
6976

77+
OP_REQUIRES(ctx, TensorShapeUtils::IsMatrix(input_shape),
78+
errors::InvalidArgument("input shape must be 2D, but got: ",
79+
input_shape.DebugString()));
80+
81+
OP_REQUIRES(ctx, TensorShapeUtils::IsVector(indices_shape),
82+
errors::InvalidArgument("indices shape must be 1D, but got: ",
83+
indices_shape.DebugString()));
84+
85+
OP_REQUIRES(ctx, TensorShapeUtils::IsMatrix(updates_shape),
86+
errors::InvalidArgument("updates shape must be 2D, but got: ",
87+
updates_shape.DebugString()));
7088
if (is_update_add_) {
7189
const TensorShape scale_shape = ctx->InputShape(3);
7290
OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(scale_shape),
@@ -75,6 +93,10 @@ class PopopsMultiUpdateOp : public XlaOpKernel, IpuOpKernel {
7593
}
7694

7795
xla::XlaBuilder& b = *ctx->builder();
96+
97+
xla::PrimitiveType input_type;
98+
OP_REQUIRES_OK(ctx,
99+
DataTypeToPrimitiveType(ctx->input_type(0), &input_type));
78100
xla::Shape xla_output_shape =
79101
TensorShapeToXLAShape(input_type, input_shape);
80102
const auto num_inputs = ctx->num_inputs();

tensorflow/python/ipu/tests/embedding_lookup_test.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -204,8 +204,8 @@ def network(x1, lr):
204204
def testGradient(self):
205205
with self.session() as sess:
206206
with ops.device('cpu'):
207-
x1 = array_ops.placeholder(np.int32, shape=[3, 4, 2])
208-
grads = array_ops.placeholder(np.float32, shape=[3, 4, 2, 16])
207+
x1 = array_ops.placeholder(np.int32, shape=[24])
208+
grads = array_ops.placeholder(np.float32, shape=[24, 16])
209209
lr = array_ops.placeholder(np.float32, shape=[])
210210

211211
def network(x1, grads, lr):
@@ -231,9 +231,10 @@ def network(x1, grads, lr):
231231
sess.run(variables.global_variables_initializer())
232232
out, indices, gradient = sess.run(
233233
r, {
234-
x1: [[[10, 11], [12, 13], [14, 15], [16, 17]],
235-
[[20, 21], [22, 23], [24, 25], [26, 27]],
236-
[[30, 31], [32, 33], [34, 35], [36, 37]]],
234+
x1: [
235+
10, 11, 12, 13, 14, 15, 16, 17, 20, 21, 22, 23, 24, 25, 26,
236+
27, 30, 31, 32, 33, 34, 35, 36, 37
237+
],
237238
grads:
238239
np.random.rand(*grads.shape),
239240
lr:

tensorflow/python/ipu/tests/replication_all_to_all_test.py

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,7 @@ def testAllGatherShapeInference(self):
7777
@test_util.deprecated_graph_mode_only
7878
def testSerializedMultiUpdateAdd(self):
7979
with ops.device('cpu'):
80-
idx = array_ops.placeholder(np.int32, shape=[16, 1])
80+
idx = array_ops.placeholder(np.int32, shape=[16])
8181
updates = array_ops.placeholder(np.float32, shape=[16, 128])
8282
scale = array_ops.placeholder(np.float32, shape=[])
8383

@@ -108,12 +108,10 @@ def my_graph(idx, updates, scale):
108108
sess.run(variables.global_variables_initializer())
109109
sess.run(
110110
out, {
111-
idx: [[1], [2], [3], [4], [1], [2], [3], [4], [10], [20], [30],
112-
[40], [100], [200], [300], [400]],
113-
updates:
114-
np.ones(updates.shape),
115-
scale:
116-
2,
111+
idx:
112+
[1, 2, 3, 4, 1, 2, 3, 4, 10, 20, 30, 40, 100, 200, 300, 400],
113+
updates: np.ones(updates.shape),
114+
scale: 2,
117115
})
118116
result = sess.run(outfeed)
119117

0 commit comments

Comments
 (0)