Commit 677fc69

galexite authored and georgepaw committed
Added indices_are_sorted to Multi{Slice,Update}
Summary: This commit adds the `indices_are_sorted` attribute to the MultiSlice, MultiUpdate and MultiUpdateAdd ops.

Reviewers: #tensorflow, #framework_ip_review_-_any_oss_or_third-party_code_use_has_been_approved, georgep

Reviewed By: #tensorflow, #framework_ip_review_-_any_oss_or_third-party_code_use_has_been_approved, georgep

Subscribers: georgep

Maniphest Tasks: T46378

Differential Revision: https://phabricator.sourcevertex.net/D52556
1 parent 9ce6427 commit 677fc69

File tree

13 files changed: +149 -99 lines changed

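For orientation, here is a minimal sketch of how a caller could use the factory signature this commit extends. The wrapper function, its name, and the `out_shape`/`input`/`indices` values are hypothetical and only illustrate the call site; the commit itself only adds the trailing `indices_are_sorted` parameter (default `false`) shown below.

    // Sketch only: assumes an existing HLO graph with `input` and `indices`
    // instructions already built elsewhere (hypothetical setup).
    #include "tensorflow/compiler/plugin/poplar/driver/tools/custom_ops/multi_slice.h"

    namespace xla {
    namespace poplarplugin {

    std::unique_ptr<HloInstruction> BuildSortedMultiSlice(
        const Shape& out_shape, HloInstruction* input, HloInstruction* indices) {
      // Pass indices_are_sorted=true explicitly; callers that omit the argument
      // keep the default of false, matching the new
      // `indices_are_sorted: bool = false` op attribute registered below.
      return CreateMultiSlice(out_shape, input, indices,
                              /*indices_are_sorted=*/true);
    }

    }  // namespace poplarplugin
    }  // namespace xla

The flag is then readable on the created instruction via `GetIndicesAreSorted()` and is serialized into the op's attribute string, as the diffs below show.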

tensorflow/compiler/plugin/poplar/driver/tools/custom_ops/multi_slice.cc

Lines changed: 44 additions & 22 deletions
@@ -14,11 +14,13 @@ limitations under the License.
 ==============================================================================*/
 
 #include "tensorflow/compiler/plugin/poplar/driver/tools/custom_ops/multi_slice.h"
+
+#include <string>
+
 #include "tensorflow/compiler/plugin/poplar/driver/tools/hlo_poplar_buffer_util.h"
 #include "tensorflow/compiler/plugin/poplar/driver/tools/matcher_predicates.h"
 #include "tensorflow/compiler/plugin/poplar/kernels/custom_kernels_util.h"
 #include "tensorflow/compiler/plugin/poplar/kernels/ops.pb.h"
-
 #include "tensorflow/compiler/xla/shape_util.h"
 
 namespace xla {
@@ -27,8 +29,10 @@ namespace poplarplugin {
 // MultiSlice
 HloMultiSliceInstruction::HloMultiSliceInstruction(
     const Shape& shape, HloInstruction* const input,
-    HloInstruction* const indices)
-    : HloPoplarInstruction(shape, {input, indices}, PoplarOp::MultiSlice) {}
+    HloInstruction* const indices, bool indices_are_sorted)
+    : HloPoplarInstruction(shape, {input, indices}, PoplarOp::MultiSlice,
+                           indices_are_sorted),
+      indices_are_sorted_(indices_are_sorted) {}
 
 absl::flat_hash_set<int64> HloMultiSliceInstruction::AllocatingIndices() const {
   return {0, 1};
@@ -61,7 +65,8 @@ std::unique_ptr<HloInstruction>
 HloMultiSliceInstruction::CloneWithNewOperandsImpl(
     const Shape& shape, absl::Span<HloInstruction* const> new_operands,
     HloCloneContext*) const {
-  return CreateMultiSlice(shape, new_operands[0], new_operands[1]);
+  return CreateMultiSlice(shape, new_operands[0], new_operands[1],
+                          indices_are_sorted_);
 }
 
 std::vector<std::string>
@@ -70,24 +75,28 @@ HloMultiSliceInstruction::ExtraPoplarAttributesToStringImpl(
   return {};
 }
 
-std::unique_ptr<HloInstruction> CreateMultiSlice(
-    const Shape& shape, HloInstruction* const input,
-    HloInstruction* const indices) {
-  return absl::make_unique<HloMultiSliceInstruction>(shape, input, indices);
+std::unique_ptr<HloInstruction> CreateMultiSlice(const Shape& shape,
+                                                 HloInstruction* const input,
+                                                 HloInstruction* const indices,
+                                                 bool indices_are_sorted) {
+  return absl::make_unique<HloMultiSliceInstruction>(shape, input, indices,
+                                                     indices_are_sorted);
 }
 
 // MultiUpdate
 HloMultiUpdateInstruction::HloMultiUpdateInstruction(
     const Shape& shape, absl::Span<HloInstruction* const> operands,
     std::size_t index_vector_dim, std::size_t update_dim,
-    uint32 serialization_factor, bool is_update)
+    uint32 serialization_factor, bool is_update, bool indices_are_sorted)
     : HloPoplarInstruction(
           shape, operands,
          is_update ? PoplarOp::MultiUpdateAdd : PoplarOp::MultiUpdate,
-          index_vector_dim, update_dim, serialization_factor),
+          index_vector_dim, update_dim, serialization_factor,
+          indices_are_sorted),
      index_vector_dim_(index_vector_dim),
      update_dim_(update_dim),
-      serialization_factor_(serialization_factor) {}
+      serialization_factor_(serialization_factor),
+      indices_are_sorted_(indices_are_sorted) {}
 
 absl::flat_hash_set<int64> HloMultiUpdateInstruction::AllocatingIndices()
     const {
@@ -138,7 +147,7 @@ HloMultiUpdateInstruction::CloneWithNewOperandsImpl(
     const Shape& shape, absl::Span<HloInstruction* const> new_operands,
     HloCloneContext*) const {
   return CreateMultiUpdate(shape, new_operands, index_vector_dim_, update_dim_,
-                           serialization_factor_);
+                           serialization_factor_, indices_are_sorted_);
 }
 
 std::vector<std::string>
@@ -149,46 +158,55 @@ HloMultiUpdateInstruction::ExtraPoplarAttributesToStringImpl(
   attributes.push_back("update_dim=" + std::to_string(update_dim_));
   attributes.push_back("serialization_factor=" +
                        std::to_string(serialization_factor_));
+  attributes.push_back("indices_are_sorted=" +
+                       std::string(indices_are_sorted_ ? "true" : "false"));
   return attributes;
 }
 
 std::unique_ptr<HloInstruction> CreateMultiUpdate(
     const Shape& shape, absl::Span<HloInstruction* const> operands,
     std::size_t index_vector_dim, std::size_t update_dim,
-    uint32 serialization_factor) {
+    uint32 serialization_factor, bool indices_are_sorted) {
   return absl::make_unique<HloMultiUpdateInstruction>(
-      shape, operands, index_vector_dim, update_dim, serialization_factor);
+      shape, operands, index_vector_dim, update_dim, serialization_factor,
+      indices_are_sorted);
 }
 
 // MultiUpdateAdd
 HloMultiUpdateAddInstruction::HloMultiUpdateAddInstruction(
     const Shape& shape, absl::Span<HloInstruction* const> operands,
     std::size_t index_vector_dim, std::size_t update_dim,
-    uint32 serialization_factor)
+    uint32 serialization_factor, bool indices_are_sorted)
    : HloMultiUpdateInstruction(shape, operands, index_vector_dim, update_dim,
-                                serialization_factor, true) {}
+                                serialization_factor, true,
+                                indices_are_sorted) {}
 
 std::unique_ptr<HloInstruction>
 HloMultiUpdateAddInstruction::CloneWithNewOperandsImpl(
     const Shape& shape, absl::Span<HloInstruction* const> new_operands,
     HloCloneContext*) const {
   return CreateMultiUpdateAdd(shape, new_operands, index_vector_dim_,
-                              update_dim_, serialization_factor_);
+                              update_dim_, serialization_factor_,
+                              indices_are_sorted_);
 }
 
 std::unique_ptr<HloInstruction> CreateMultiUpdateAdd(
     const Shape& shape, absl::Span<HloInstruction* const> operands,
     std::size_t index_vector_dim, std::size_t update_dim,
-    uint32 serialization_factor) {
+    uint32 serialization_factor, bool indices_are_sorted) {
  return absl::make_unique<HloMultiUpdateAddInstruction>(
-      shape, operands, index_vector_dim, update_dim, serialization_factor);
+      shape, operands, index_vector_dim, update_dim, serialization_factor,
+      indices_are_sorted);
 }
 
 namespace {
 StatusOr<std::unique_ptr<HloInstruction>> HloMultiSliceInstructionFactoryFunc(
     HloCustomCallInstruction* call) {
+  auto attribute_map = IPUCustomKernelsUtil::AttributeMap(call);
+  TF_ASSIGN_OR_RETURN(bool indices_are_sorted,
+                      attribute_map.GetAttributeAsBool("indices_are_sorted"));
   return CreateMultiSlice(call->shape(), call->mutable_operand(0),
-                          call->mutable_operand(1));
+                          call->mutable_operand(1), indices_are_sorted);
 }
 
 StatusOr<std::unique_ptr<HloInstruction>> HloMultiUpdateInstructionFactoryFunc(
@@ -198,8 +216,10 @@ StatusOr<std::unique_ptr<HloInstruction>> HloMultiUpdateInstructionFactoryFunc(
                       attribute_map.GetAttributeAsUInt64("index_vector_dim"));
   TF_ASSIGN_OR_RETURN(uint64 update_dim,
                       attribute_map.GetAttributeAsUInt64("update_dim"));
+  TF_ASSIGN_OR_RETURN(bool indices_are_sorted,
+                      attribute_map.GetAttributeAsBool("indices_are_sorted"));
   return CreateMultiUpdate(call->shape(), call->operands(), index_vector_dim,
-                           update_dim, 1);
+                           update_dim, 1, indices_are_sorted);
 }
 
 StatusOr<std::unique_ptr<HloInstruction>>
@@ -209,8 +229,10 @@ HloMultiUpdateAddInstructionFactoryFunc(HloCustomCallInstruction* call) {
                       attribute_map.GetAttributeAsUInt64("index_vector_dim"));
   TF_ASSIGN_OR_RETURN(uint64 update_dim,
                       attribute_map.GetAttributeAsUInt64("update_dim"));
+  TF_ASSIGN_OR_RETURN(bool indices_are_sorted,
+                      attribute_map.GetAttributeAsBool("indices_are_sorted"));
   return CreateMultiUpdateAdd(call->shape(), call->operands(), index_vector_dim,
-                              update_dim, 1);
+                              update_dim, 1, indices_are_sorted);
 }
 
 static HloPoplarInstructionFactory multi_slice_factory(

tensorflow/compiler/plugin/poplar/driver/tools/custom_ops/multi_slice.h

Lines changed: 19 additions & 8 deletions
@@ -25,7 +25,8 @@ class HloMultiSliceInstruction : public HloPoplarInstruction {
  public:
   explicit HloMultiSliceInstruction(const Shape& shape,
                                     HloInstruction* const input,
-                                    HloInstruction* const indices);
+                                    HloInstruction* const indices,
+                                    bool indices_are_sorted = false);
 
   absl::flat_hash_set<int64> AllocatingIndices() const override;
   bool AllocatingOutput() const override;
@@ -39,10 +40,15 @@ class HloMultiSliceInstruction : public HloPoplarInstruction {
 
   bool IsPopOpsElementwise() const override;
 
+  // Whether or not the given indices are sorted.
+  bool GetIndicesAreSorted() const { return indices_are_sorted_; }
+
  protected:
   std::vector<std::string> ExtraPoplarAttributesToStringImpl(
       const HloPrintOptions& options) const override;
 
+  const bool indices_are_sorted_;
+
  private:
   std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
       const Shape& shape, absl::Span<HloInstruction* const>,
@@ -56,7 +62,8 @@ class HloMultiUpdateInstruction : public HloPoplarInstruction {
                                      std::size_t index_vector_dim,
                                      std::size_t update_dim,
                                      uint32 serialization_factor,
-                                     bool is_update_add = false);
+                                     bool is_update_add = false,
+                                     bool indices_are_sorted = false);
 
   absl::flat_hash_set<int64> AllocatingIndices() const override;
   bool AllocatingOutput() const override;
@@ -79,13 +86,17 @@ class HloMultiUpdateInstruction : public HloPoplarInstruction {
   // Factor used for serializing the multi update.
   std::size_t GetSerializationFactor() const { return serialization_factor_; }
 
+  // Whether or not the given indices are sorted.
+  bool GetIndicesAreSorted() const { return indices_are_sorted_; }
+
  protected:
   std::vector<std::string> ExtraPoplarAttributesToStringImpl(
       const HloPrintOptions& options) const override;
 
   const std::size_t index_vector_dim_;
   const std::size_t update_dim_;
   const uint32 serialization_factor_;
+  const bool indices_are_sorted_;
 
  private:
   std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
@@ -98,27 +109,27 @@ class HloMultiUpdateAddInstruction : public HloMultiUpdateInstruction {
   explicit HloMultiUpdateAddInstruction(
       const Shape& shape, absl::Span<HloInstruction* const> operands,
       std::size_t index_vector_dim, std::size_t update_dim,
-      uint32 serialization_factor);
+      uint32 serialization_factor, bool indices_are_sorted);
 
  private:
   std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
       const Shape& shape, absl::Span<HloInstruction* const>,
      HloCloneContext*) const override;
 };
 
-std::unique_ptr<HloInstruction> CreateMultiSlice(const Shape& shape,
-                                                 HloInstruction* const input,
-                                                 HloInstruction* const indices);
+std::unique_ptr<HloInstruction> CreateMultiSlice(
+    const Shape& shape, HloInstruction* const input,
+    HloInstruction* const indices, bool indices_are_sorted = false);
 
 std::unique_ptr<HloInstruction> CreateMultiUpdate(
     const Shape& shape, absl::Span<HloInstruction* const> operands,
     std::size_t index_vector_dim, std::size_t update_dim,
-    uint32 serialization_factor = 1);
+    uint32 serialization_factor = 1, bool indices_are_sorted = false);
 
 std::unique_ptr<HloInstruction> CreateMultiUpdateAdd(
     const Shape& shape, absl::Span<HloInstruction* const> operands,
     std::size_t index_vector_dim, std::size_t update_dim,
-    uint32 serialization_factor = 1);
+    uint32 serialization_factor = 1, bool indices_are_sorted = false);
 
 }  // namespace poplarplugin
 }  // namespace xla

tensorflow/compiler/plugin/poplar/kernels/popops/multi_slice.cc

Lines changed: 14 additions & 1 deletion
@@ -28,6 +28,14 @@ class PopopsMultiSliceOp : public XlaOpKernel, IpuOpKernel {
  public:
   using XlaOpKernel::XlaOpKernel;
 
+  explicit PopopsMultiSliceOp(OpKernelConstruction* ctx)
+      : XlaOpKernel(ctx), IpuOpKernel() {
+    bool indices_are_sorted;
+    OP_REQUIRES_OK(ctx,
+                   ctx->GetAttr("indices_are_sorted", &indices_are_sorted));
+    attribute_map_.AddAttribute("indices_are_sorted", indices_are_sorted);
+  }
+
   void Compile(XlaOpKernelContext* ctx) override {
     const TensorShape input_shape = ctx->InputShape(0);
     const TensorShape indices_shape = ctx->InputShape(1);
@@ -67,7 +75,12 @@ REGISTER_IPU_OP("IpuMultiSlice", PopopsMultiSliceOp);
 class PopopsMultiUpdateOp : public XlaOpKernel, IpuOpKernel {
  public:
   PopopsMultiUpdateOp(OpKernelConstruction* ctx, bool is_update_add = false)
-      : XlaOpKernel(ctx), IpuOpKernel(), is_update_add_(is_update_add) {}
+      : XlaOpKernel(ctx), IpuOpKernel(), is_update_add_(is_update_add) {
+    bool indices_are_sorted;
+    OP_REQUIRES_OK(ctx,
+                   ctx->GetAttr("indices_are_sorted", &indices_are_sorted));
+    attribute_map_.AddAttribute("indices_are_sorted", indices_are_sorted);
+  }
 
   void Compile(XlaOpKernelContext* ctx) override {
     const TensorShape input_shape = ctx->InputShape(0);
tensorflow/compiler/plugin/poplar/ops/popops/multi_slice.cc

Lines changed: 3 additions & 0 deletions
@@ -23,6 +23,7 @@ REGISTER_OP("IpuMultiSlice")
     .Input("indices: int32")
     .Output("output: dtype")
     .Attr("dtype: {float16, float32, int32}")
+    .Attr("indices_are_sorted: bool = false")
     .SetShapeFn([](shape_inference::InferenceContext* c) {
       // outshape = list(ids.shape) + [N]
       shape_inference::ShapeHandle output, N, in_shape, indices;
@@ -42,6 +43,7 @@ REGISTER_OP("IpuMultiUpdate")
     .Input("updates: dtype")
     .Output("output: dtype")
     .Attr("dtype: {float16, float32, int32}")
+    .Attr("indices_are_sorted: bool = false")
     .SetShapeFn(shape_inference::UnchangedShape);
 
 REGISTER_OP("IpuMultiUpdateAdd")
@@ -51,6 +53,7 @@ REGISTER_OP("IpuMultiUpdateAdd")
     .Input("scale: dtype")
     .Output("output: dtype")
     .Attr("dtype: {float16, float32, int32}")
+    .Attr("indices_are_sorted: bool = false")
     .SetShapeFn(shape_inference::UnchangedShape);
 
 }  // namespace tensorflow

tensorflow/compiler/plugin/poplar/tests/all_to_all_finder_test.cc

Lines changed: 4 additions & 4 deletions
@@ -49,7 +49,7 @@ ENTRY main {
   updates = f32[24,16] parameter(1)
   zero = f32[] constant(0)
   big_zero = f32[1000,16] broadcast(zero), dimensions={}
-  operand = f32[1000,16] custom-call(big_zero, offsets, updates), custom_call_target="MultiUpdate", backend_config="{\"index_vector_dim\":1,\"update_dim\":1}\n"
+  operand = f32[1000,16] custom-call(big_zero, offsets, updates), custom_call_target="MultiUpdate", backend_config="{\"index_vector_dim\":1,\"update_dim\":1,\"indices_are_sorted\":false}\n"
   operand_all = f32[1000,16] all-reduce(operand), to_apply=add
   ROOT operand_norm = f32[1000,16] custom-call(operand_all), custom_call_target="ReplicationNormalise", backend_config="{}\n"
 }
@@ -103,7 +103,7 @@ ENTRY main {
   updates = f32[24,16] parameter(1)
   zero = f32[] constant(0)
   big_zero = f32[1000,16] broadcast(zero), dimensions={}
-  operand = f32[1000,16] custom-call(big_zero, offsets, updates), custom_call_target="MultiUpdate", backend_config="{\"index_vector_dim\":1,\"update_dim\":1}\n"
+  operand = f32[1000,16] custom-call(big_zero, offsets, updates), custom_call_target="MultiUpdate", backend_config="{\"index_vector_dim\":1,\"update_dim\":1,\"indices_are_sorted\":false}\n"
   operand_all = f32[1000,16] all-reduce(operand), to_apply=add
   ROOT operand_norm = f32[1000,16] custom-call(operand_all), custom_call_target="ReplicationNormalise", backend_config="{}\n"
 }
@@ -158,7 +158,7 @@ ENTRY main {
   zero = f32[] constant(0)
   scale = f32[] constant(1)
   big_zero = f32[1000,16] broadcast(zero), dimensions={}
-  operand = f32[1000,16] custom-call(big_zero, offsets, updates, scale), custom_call_target="MultiUpdateAdd", backend_config="{\"index_vector_dim\":1,\"update_dim\":1}\n"
+  operand = f32[1000,16] custom-call(big_zero, offsets, updates, scale), custom_call_target="MultiUpdateAdd", backend_config="{\"index_vector_dim\":1,\"update_dim\":1,\"indices_are_sorted\":false}\n"
   operand_all = f32[1000,16] all-reduce(operand), to_apply=add
   ROOT operand_norm = f32[1000,16] custom-call(operand_all), custom_call_target="ReplicationNormalise", backend_config="{}\n"
 }
@@ -215,7 +215,7 @@ ENTRY main {
   zero = f32[] constant(0)
   scale = f32[] constant(1)
   big_zero = f32[1000,16] broadcast(zero), dimensions={}
-  operand = f32[1000,16] custom-call(big_zero, offsets, updates, scale), custom_call_target="MultiUpdateAdd", backend_config="{\"index_vector_dim\":1,\"update_dim\":1}\n"
+  operand = f32[1000,16] custom-call(big_zero, offsets, updates, scale), custom_call_target="MultiUpdate", backend_config="{\"index_vector_dim\":1,\"update_dim\":1,\"indices_are_sorted\":false}\n"
   operand_all = f32[1000,16] all-reduce(operand), to_apply=add
   ROOT operand_norm = f32[1000,16] custom-call(operand_all), custom_call_target="ReplicationNormalise", backend_config="{}\n"
 }

tensorflow/compiler/plugin/poplar/tests/allocation_finder_test.cc

Lines changed: 5 additions & 5 deletions
@@ -204,7 +204,7 @@ ENTRY c1 {
   p0 = f16[2, 64] parameter(0)
   p1 = f16[1024, 64] parameter(1)
   p2 = s32[2] parameter(2)
-  slice = f16[2, 64] custom-call(p1, p2), custom_call_target="MultiSlice"
+  slice = f16[2, 64] custom-call(p1, p2), custom_call_target="MultiSlice", backend_config="{\"indices_are_sorted\":false}"
   p1_t = f16[64, 1024] transpose(p1), dimensions={1, 0}
   dot = f16[2, 1024] dot(p0, p1_t), lhs_contracting_dims={1}, rhs_contracting_dims={0}
   ROOT t = (f16[2, 64], f16[2, 1024]) tuple(slice, dot)
@@ -269,7 +269,7 @@ ENTRY c1 {
   p1 = f16[64, 1024] parameter(1)
   p2 = s32[2] parameter(2)
   p1_t = f16[1024, 64] transpose(p1), dimensions={1, 0}
-  slice = f16[2, 64] custom-call(p1_t, p2), custom_call_target="MultiSlice"
+  slice = f16[2, 64] custom-call(p1_t, p2), custom_call_target="MultiSlice", backend_config="{\"indices_are_sorted\":false}"
   dot = f16[2, 1024] dot(p0, p1), lhs_contracting_dims={1}, rhs_contracting_dims={0}
   ROOT t = (f16[2, 64], f16[2, 1024]) tuple(slice, dot)
 }
@@ -332,7 +332,7 @@ ENTRY c1 {
   p1 = f16[64, 1024] parameter(1)
   p2 = s32[2] parameter(2)
   p1_t = f16[1024, 64] transpose(p1), dimensions={1, 0}
-  mu = f16[1024, 64] custom-call(p1_t, p2, p0), custom_call_target="MultiUpdate", backend_config="{\"index_vector_dim\":1,\"update_dim\":1}\n"
+  mu = f16[1024, 64] custom-call(p1_t, p2, p0), custom_call_target="MultiUpdate", backend_config="{\"index_vector_dim\":1,\"update_dim\":1,\"indices_are_sorted\":false}\n"
   dot = f16[2, 1024] dot(p0, p1), lhs_contracting_dims={1}, rhs_contracting_dims={0}
   ROOT t = (f16[2, 64], f16[2, 1024]) tuple(mu, dot)
 }
@@ -4448,8 +4448,8 @@ ENTRY c1 {
   p1 = f16[64, 1024] parameter(1)
   p2 = s32[2] parameter(2)
   p1_t = f16[1024, 64] transpose(p1), dimensions={1, 0}
-  mu1 = f16[1024, 64] custom-call(p1_t, p2, p0), custom_call_target="MultiUpdateAdd", backend_config="{\"index_vector_dim\":1,\"update_dim\":1}\n"
-  mu2 = f16[1024, 64] custom-call(mu1, p2, p0), custom_call_target="MultiUpdateAdd", backend_config="{\"index_vector_dim\":1,\"update_dim\":1}\n"
+  mu1 = f16[1024, 64] custom-call(p1_t, p2, p0), custom_call_target="MultiUpdate", backend_config="{\"index_vector_dim\":1,\"update_dim\":1,\"indices_are_sorted\":false}\n"
+  mu2 = f16[1024, 64] custom-call(mu1, p2, p0), custom_call_target="MultiUpdate", backend_config="{\"index_vector_dim\":1,\"update_dim\":1,\"indices_are_sorted\":false}\n"
   ROOT t = (f16[1024, 64], f16[1024, 64]) tuple(mu1, mu2)
 }
 )";

0 commit comments