@@ -14,11 +14,13 @@ limitations under the License.
1414==============================================================================*/
1515
1616#include " tensorflow/compiler/plugin/poplar/driver/tools/custom_ops/multi_slice.h"
17+
18+ #include < string>
19+
1720#include " tensorflow/compiler/plugin/poplar/driver/tools/hlo_poplar_buffer_util.h"
1821#include " tensorflow/compiler/plugin/poplar/driver/tools/matcher_predicates.h"
1922#include " tensorflow/compiler/plugin/poplar/kernels/custom_kernels_util.h"
2023#include " tensorflow/compiler/plugin/poplar/kernels/ops.pb.h"
21-
2224#include " tensorflow/compiler/xla/shape_util.h"
2325
2426namespace xla {
@@ -27,8 +29,10 @@ namespace poplarplugin {
2729// MultiSlice
2830HloMultiSliceInstruction::HloMultiSliceInstruction (
2931 const Shape& shape, HloInstruction* const input,
30- HloInstruction* const indices)
31- : HloPoplarInstruction(shape, {input, indices}, PoplarOp::MultiSlice) {}
32+ HloInstruction* const indices, bool indices_are_sorted)
33+ : HloPoplarInstruction(shape, {input, indices}, PoplarOp::MultiSlice,
34+ indices_are_sorted),
35+ indices_are_sorted_ (indices_are_sorted) {}
3236
3337absl::flat_hash_set<int64> HloMultiSliceInstruction::AllocatingIndices () const {
3438 return {0 , 1 };
@@ -61,7 +65,8 @@ std::unique_ptr<HloInstruction>
6165HloMultiSliceInstruction::CloneWithNewOperandsImpl (
6266 const Shape& shape, absl::Span<HloInstruction* const > new_operands,
6367 HloCloneContext*) const {
64- return CreateMultiSlice (shape, new_operands[0 ], new_operands[1 ]);
68+ return CreateMultiSlice (shape, new_operands[0 ], new_operands[1 ],
69+ indices_are_sorted_);
6570}
6671
6772std::vector<std::string>
@@ -70,24 +75,28 @@ HloMultiSliceInstruction::ExtraPoplarAttributesToStringImpl(
7075 return {};
7176}
7277
73- std::unique_ptr<HloInstruction> CreateMultiSlice (
74- const Shape& shape, HloInstruction* const input,
75- HloInstruction* const indices) {
76- return absl::make_unique<HloMultiSliceInstruction>(shape, input, indices);
78+ std::unique_ptr<HloInstruction> CreateMultiSlice (const Shape& shape,
79+ HloInstruction* const input,
80+ HloInstruction* const indices,
81+ bool indices_are_sorted) {
82+ return absl::make_unique<HloMultiSliceInstruction>(shape, input, indices,
83+ indices_are_sorted);
7784}
7885
7986// MultiUpdate
8087HloMultiUpdateInstruction::HloMultiUpdateInstruction (
8188 const Shape& shape, absl::Span<HloInstruction* const > operands,
8289 std::size_t index_vector_dim, std::size_t update_dim,
83- uint32 serialization_factor, bool is_update)
90+ uint32 serialization_factor, bool is_update, bool indices_are_sorted )
8491 : HloPoplarInstruction(
8592 shape, operands,
8693 is_update ? PoplarOp::MultiUpdateAdd : PoplarOp::MultiUpdate,
87- index_vector_dim, update_dim, serialization_factor),
94+ index_vector_dim, update_dim, serialization_factor,
95+ indices_are_sorted),
8896 index_vector_dim_ (index_vector_dim),
8997 update_dim_(update_dim),
90- serialization_factor_(serialization_factor) {}
98+ serialization_factor_(serialization_factor),
99+ indices_are_sorted_(indices_are_sorted) {}
91100
92101absl::flat_hash_set<int64> HloMultiUpdateInstruction::AllocatingIndices ()
93102 const {
@@ -138,7 +147,7 @@ HloMultiUpdateInstruction::CloneWithNewOperandsImpl(
138147 const Shape& shape, absl::Span<HloInstruction* const > new_operands,
139148 HloCloneContext*) const {
140149 return CreateMultiUpdate (shape, new_operands, index_vector_dim_, update_dim_,
141- serialization_factor_);
150+ serialization_factor_, indices_are_sorted_ );
142151}
143152
144153std::vector<std::string>
@@ -149,46 +158,55 @@ HloMultiUpdateInstruction::ExtraPoplarAttributesToStringImpl(
149158 attributes.push_back (" update_dim=" + std::to_string (update_dim_));
150159 attributes.push_back (" serialization_factor=" +
151160 std::to_string (serialization_factor_));
161+ attributes.push_back (" indices_are_sorted=" +
162+ std::string (indices_are_sorted_ ? " true" : " false" ));
152163 return attributes;
153164}
154165
155166std::unique_ptr<HloInstruction> CreateMultiUpdate (
156167 const Shape& shape, absl::Span<HloInstruction* const > operands,
157168 std::size_t index_vector_dim, std::size_t update_dim,
158- uint32 serialization_factor) {
169+ uint32 serialization_factor, bool indices_are_sorted ) {
159170 return absl::make_unique<HloMultiUpdateInstruction>(
160- shape, operands, index_vector_dim, update_dim, serialization_factor);
171+ shape, operands, index_vector_dim, update_dim, serialization_factor,
172+ indices_are_sorted);
161173}
162174
163175// MultiUpdateAdd
164176HloMultiUpdateAddInstruction::HloMultiUpdateAddInstruction (
165177 const Shape& shape, absl::Span<HloInstruction* const > operands,
166178 std::size_t index_vector_dim, std::size_t update_dim,
167- uint32 serialization_factor)
179+ uint32 serialization_factor, bool indices_are_sorted )
168180 : HloMultiUpdateInstruction(shape, operands, index_vector_dim, update_dim,
169- serialization_factor, true ) {}
181+ serialization_factor, true ,
182+ indices_are_sorted) {}
170183
171184std::unique_ptr<HloInstruction>
172185HloMultiUpdateAddInstruction::CloneWithNewOperandsImpl (
173186 const Shape& shape, absl::Span<HloInstruction* const > new_operands,
174187 HloCloneContext*) const {
175188 return CreateMultiUpdateAdd (shape, new_operands, index_vector_dim_,
176- update_dim_, serialization_factor_);
189+ update_dim_, serialization_factor_,
190+ indices_are_sorted_);
177191}
178192
179193std::unique_ptr<HloInstruction> CreateMultiUpdateAdd (
180194 const Shape& shape, absl::Span<HloInstruction* const > operands,
181195 std::size_t index_vector_dim, std::size_t update_dim,
182- uint32 serialization_factor) {
196+ uint32 serialization_factor, bool indices_are_sorted ) {
183197 return absl::make_unique<HloMultiUpdateAddInstruction>(
184- shape, operands, index_vector_dim, update_dim, serialization_factor);
198+ shape, operands, index_vector_dim, update_dim, serialization_factor,
199+ indices_are_sorted);
185200}
186201
187202namespace {
188203StatusOr<std::unique_ptr<HloInstruction>> HloMultiSliceInstructionFactoryFunc (
189204 HloCustomCallInstruction* call) {
205+ auto attribute_map = IPUCustomKernelsUtil::AttributeMap (call);
206+ TF_ASSIGN_OR_RETURN (bool indices_are_sorted,
207+ attribute_map.GetAttributeAsBool (" indices_are_sorted" ));
190208 return CreateMultiSlice (call->shape (), call->mutable_operand (0 ),
191- call->mutable_operand (1 ));
209+ call->mutable_operand (1 ), indices_are_sorted );
192210}
193211
194212StatusOr<std::unique_ptr<HloInstruction>> HloMultiUpdateInstructionFactoryFunc (
@@ -198,8 +216,10 @@ StatusOr<std::unique_ptr<HloInstruction>> HloMultiUpdateInstructionFactoryFunc(
198216 attribute_map.GetAttributeAsUInt64 (" index_vector_dim" ));
199217 TF_ASSIGN_OR_RETURN (uint64 update_dim,
200218 attribute_map.GetAttributeAsUInt64 (" update_dim" ));
219+ TF_ASSIGN_OR_RETURN (bool indices_are_sorted,
220+ attribute_map.GetAttributeAsBool (" indices_are_sorted" ));
201221 return CreateMultiUpdate (call->shape (), call->operands (), index_vector_dim,
202- update_dim, 1 );
222+ update_dim, 1 , indices_are_sorted );
203223}
204224
205225StatusOr<std::unique_ptr<HloInstruction>>
@@ -209,8 +229,10 @@ HloMultiUpdateAddInstructionFactoryFunc(HloCustomCallInstruction* call) {
209229 attribute_map.GetAttributeAsUInt64 (" index_vector_dim" ));
210230 TF_ASSIGN_OR_RETURN (uint64 update_dim,
211231 attribute_map.GetAttributeAsUInt64 (" update_dim" ));
232+ TF_ASSIGN_OR_RETURN (bool indices_are_sorted,
233+ attribute_map.GetAttributeAsBool (" indices_are_sorted" ));
212234 return CreateMultiUpdateAdd (call->shape (), call->operands (), index_vector_dim,
213- update_dim, 1 );
235+ update_dim, 1 , indices_are_sorted );
214236}
215237
216238static HloPoplarInstructionFactory multi_slice_factory (
0 commit comments