@@ -30,7 +30,17 @@ class PopopsMultiSliceOp : public XlaOpKernel, IpuOpKernel {
3030
3131 void Compile (XlaOpKernelContext* ctx) override {
3232 const TensorShape input_shape = ctx->InputShape (0 );
33- TensorShape output_shape = ctx->InputShape (1 );
33+ const TensorShape indices_shape = ctx->InputShape (1 );
34+
35+ OP_REQUIRES (ctx, TensorShapeUtils::IsMatrix (input_shape),
36+ errors::InvalidArgument (" input shape must be 2D, but got: " ,
37+ input_shape.DebugString ()));
38+
39+ OP_REQUIRES (ctx, TensorShapeUtils::IsVector (indices_shape),
40+ errors::InvalidArgument (" indices shape must be 1D, but got: " ,
41+ indices_shape.DebugString ()));
42+
43+ TensorShape output_shape = indices_shape;
3444 output_shape.AddDim (input_shape.dim_size (1 ));
3545
3646 xla::PrimitiveType input_type;
@@ -63,10 +73,18 @@ class PopopsMultiUpdateOp : public XlaOpKernel, IpuOpKernel {
6373 const TensorShape input_shape = ctx->InputShape (0 );
6474 const TensorShape indices_shape = ctx->InputShape (1 );
6575 const TensorShape updates_shape = ctx->InputShape (2 );
66- xla::PrimitiveType input_type;
67- OP_REQUIRES_OK (ctx,
68- DataTypeToPrimitiveType (ctx->input_type (0 ), &input_type));
6976
77+ OP_REQUIRES (ctx, TensorShapeUtils::IsMatrix (input_shape),
78+ errors::InvalidArgument (" input shape must be 2D, but got: " ,
79+ input_shape.DebugString ()));
80+
81+ OP_REQUIRES (ctx, TensorShapeUtils::IsVector (indices_shape),
82+ errors::InvalidArgument (" indices shape must be 1D, but got: " ,
83+ indices_shape.DebugString ()));
84+
85+ OP_REQUIRES (ctx, TensorShapeUtils::IsMatrix (updates_shape),
86+ errors::InvalidArgument (" updates shape must be 2D, but got: " ,
87+ updates_shape.DebugString ()));
7088 if (is_update_add_) {
7189 const TensorShape scale_shape = ctx->InputShape (3 );
7290 OP_REQUIRES (ctx, TensorShapeUtils::IsScalar (scale_shape),
@@ -75,6 +93,10 @@ class PopopsMultiUpdateOp : public XlaOpKernel, IpuOpKernel {
7593 }
7694
7795 xla::XlaBuilder& b = *ctx->builder ();
96+
97+ xla::PrimitiveType input_type;
98+ OP_REQUIRES_OK (ctx,
99+ DataTypeToPrimitiveType (ctx->input_type (0 ), &input_type));
78100 xla::Shape xla_output_shape =
79101 TensorShapeToXLAShape (input_type, input_shape);
80102 const auto num_inputs = ctx->num_inputs ();
0 commit comments