@@ -29,9 +29,11 @@ limitations under the License.
2929#include " tensorflow/compiler/plugin/poplar/driver/tools/matcher_predicates.h"
3030#include " tensorflow/compiler/plugin/poplar/driver/tools/util.h"
3131#include " tensorflow/compiler/plugin/poplar/kernels/ops.pb.h"
32+
3233#include " tensorflow/compiler/xla/literal_util.h"
3334#include " tensorflow/compiler/xla/service/call_graph.h"
3435#include " tensorflow/compiler/xla/service/hlo_casting_utils.h"
36+ #include " tensorflow/compiler/xla/service/hlo_creation_utils.h"
3537#include " tensorflow/compiler/xla/shape_util.h"
3638
3739namespace xla {
@@ -280,8 +282,6 @@ HloComputation* ReplaceResourceUpdateFunction(
280282 CreateMultiUpdateAdd (old_sink_shape,
281283 {zero_broadcast.get (), new_indices_reshape.get (),
282284 new_grads_arg.get (), const_1.get ()},
283- multi_update_add->GetIndexVectorDimension (),
284- multi_update_add->GetUpdateSliceDimension (),
285285 multi_update_add->GetSerializationFactor ());
286286
287287 for (auto old_sink_user : old_sink_arg->users ()) {
@@ -355,8 +355,15 @@ StatusOr<HloComputation*> ReplaceAccumulatorCaller(
355355 ShapeUtil::MakeShape (update_index->shape ().element_type (),
356356 {plan.mini_batch_size }),
357357 update_index.get (), {});
358- auto update_index_broadcast_indices = HloInstruction::CreateBroadcast (
359- ShapeUtil::MakeShape (update_index->shape ().element_type (), {1 }),
358+
359+ auto update_index_broadcast_grads_reshaped = HloInstruction::CreateReshape (
360+ ShapeUtil::MakeShape (update_index->shape ().element_type (),
361+ /* dimensions=*/ {plan.mini_batch_size , 1 }),
362+ update_index_broadcast_grads.get (), {});
363+
364+ auto update_index_broadcast_indices_reshaped = HloInstruction::CreateReshape (
365+ ShapeUtil::MakeShape (update_index->shape ().element_type (),
366+ /* dimensions=*/ {1 , 1 }),
360367 update_index.get (), {});
361368
362369 auto const_int_1 = HloInstruction::CreateConstant (LiteralUtil::One (S32));
@@ -369,15 +376,19 @@ StatusOr<HloComputation*> ReplaceAccumulatorCaller(
369376 plan.accum_grads_shape ,
370377 {pipeline_stage_accum_grads_param ? pipeline_stage_accum_grads_param.get ()
371378 : accum_grads.get (),
372- update_index_broadcast_grads.get (), grads, scale},
373- 1 , 1 , multi_update_add->GetSerializationFactor ());
374- auto indices_update = CreateMultiUpdateAdd (
375- plan.accum_indices_shape ,
376- {pipeline_stage_accum_indices_param
377- ? pipeline_stage_accum_indices_param.get ()
378- : accum_indices.get (),
379- update_index_broadcast_indices.get (), indices, const_int_1.get ()},
380- 0 , 0 , multi_update_add->GetSerializationFactor ());
379+ update_index_broadcast_grads_reshaped.get (), grads, scale},
380+ multi_update_add->GetSerializationFactor ());
381+
382+ TF_ASSIGN_OR_RETURN (auto indices_transpose,
383+ MakeTransposeHlo (indices, {1 , 0 }));
384+ auto indices_update =
385+ CreateMultiUpdateAdd (plan.accum_indices_shape ,
386+ {pipeline_stage_accum_indices_param
387+ ? pipeline_stage_accum_indices_param.get ()
388+ : accum_indices.get (),
389+ update_index_broadcast_indices_reshaped.get (),
390+ indices_transpose, const_int_1.get ()},
391+ multi_update_add->GetSerializationFactor ());
381392
382393 std::unique_ptr<HloInstruction> grads_update_gte, indices_update_gte;
383394 if (pipeline_stage) {
@@ -525,6 +536,11 @@ absl::optional<Candidate> CheckEmbeddingsCandidate(HloInstruction* inst) {
525536 " found." ;
526537 return absl::nullopt ;
527538 }
539+ // Updates need to be 2D.
540+ if (inst->shape ().rank () != 2 ) {
541+ VLOG (2 ) << " The shape needs to be 2D." ;
542+ return absl::nullopt ;
543+ }
528544 return Candidate{Cast<HloGradientAccumulatorSink>(inst), {}};
529545}
530546
@@ -627,6 +643,11 @@ absl::optional<PipelineCandidate> CheckPipelineEmbeddingsCandidate(
627643 return absl::nullopt ;
628644 }
629645
646+ if (grad_create->shape ().rank () != 2 ) {
647+ VLOG (2 ) << " The shape needs to be 2D." ;
648+ return absl::nullopt ;
649+ }
650+
630651 PipelineCandidate candidate{grad_create, grad_add, multi_update_add,
631652 grad_sink, grad_sink_gte, pipeline_stage,
632653 resource_update};
0 commit comments