Skip to content

Commit 436cc37

Browse files
committed
Simplify multi update instruction
Summary: TF2.4 Only Ref T45278 Test Plan: CI Reviewers: #tensorflow, #framework_ip_review_-_any_oss_or_third-party_code_use_has_been_approved, vladimirm, jakeh, georgew Reviewed By: #tensorflow, #framework_ip_review_-_any_oss_or_third-party_code_use_has_been_approved, georgew Maniphest Tasks: T45278 Differential Revision: https://phabricator.sourcevertex.net/D53082
1 parent 677fc69 commit 436cc37

28 files changed

+239
-575
lines changed

tensorflow/compiler/plugin/poplar/BUILD

Lines changed: 0 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -226,7 +226,6 @@ cc_library(
226226
"driver/passes/multi_conv_fixer.cc",
227227
"driver/passes/multi_slice_combiner.cc",
228228
"driver/passes/multi_update_apply.cc",
229-
"driver/passes/multi_update_canonicalize.cc",
230229
"driver/passes/multi_update_combiner.cc",
231230
"driver/passes/multi_update_scale_apply.cc",
232231
"driver/passes/multi_use_feeds_finder.cc",
@@ -435,7 +434,6 @@ cc_library(
435434
"driver/passes/multi_conv_fixer.h",
436435
"driver/passes/multi_slice_combiner.h",
437436
"driver/passes/multi_update_apply.h",
438-
"driver/passes/multi_update_canonicalize.h",
439437
"driver/passes/multi_update_combiner.h",
440438
"driver/passes/multi_update_scale_apply.h",
441439
"driver/passes/multi_use_feeds_finder.h",
@@ -3121,22 +3119,6 @@ xla_test(
31213119
],
31223120
)
31233121

3124-
xla_test(
3125-
name = "multi_update_canonicalize_test",
3126-
size = "small",
3127-
srcs = ["tests/multi_update_canonicalize_test.cc"],
3128-
backends = ["poplar"],
3129-
copts = ["-fexceptions"],
3130-
deps = [
3131-
":optimizers",
3132-
"//tensorflow/compiler/xla/service:hlo_matchers",
3133-
"//tensorflow/compiler/xla/service:hlo_query",
3134-
"//tensorflow/compiler/xla/service:shape_inference",
3135-
"//tensorflow/compiler/xla/tests:hlo_test_base",
3136-
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
3137-
],
3138-
)
3139-
31403122
xla_test(
31413123
name = "resource_update_fixer_test",
31423124
size = "small",
@@ -5848,7 +5830,6 @@ test_suite(
58485830
"multi_run_test",
58495831
"multi_slice_combiner_test",
58505832
"multi_update_apply_test",
5851-
"multi_update_canonicalize_test",
58525833
"multi_update_combiner_test",
58535834
"multi_update_scale_apply_test",
58545835
"multi_use_feeds_finder_test",

tensorflow/compiler/plugin/poplar/driver/passes/all_to_all_finder.cc

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -161,9 +161,6 @@ static Status ApplyTransformation(HloMatcherMatched& match,
161161

162162
HloMultiUpdateInstruction* multi_update =
163163
Cast<HloMultiUpdateInstruction>(match.instruction_mapping[2]);
164-
// This pass should have been run after MultiUpdateCanonicalize.
165-
CHECK_EQ(multi_update->GetIndexVectorDimension(), 1);
166-
CHECK_EQ(multi_update->GetUpdateSliceDimension(), 1);
167164

168165
HloInstruction* broadcast = match.instruction_mapping[3];
169166
HloInstruction* indices = match.instruction_mapping[5];
@@ -199,14 +196,14 @@ static Status ApplyTransformation(HloMatcherMatched& match,
199196
HloInstruction* scale = match.instruction_mapping[7];
200197
output = comp->AddInstruction(CreateMultiUpdateAdd(
201198
broadcast->shape(),
202-
{broadcast, reduced_indices, normalized_updates, scale}, 1, 1,
199+
{broadcast, reduced_indices, normalized_updates, scale},
203200
serialization_factor));
204201
} else {
205202
// Create MultiUpdate.
206203
CHECK_EQ(match.pattern_idx, 1);
207204
output = comp->AddInstruction(CreateMultiUpdate(
208-
broadcast->shape(), {broadcast, reduced_indices, normalized_updates}, 1,
209-
1, serialization_factor));
205+
broadcast->shape(), {broadcast, reduced_indices, normalized_updates},
206+
serialization_factor));
210207
}
211208

212209
// Replace with the new output.

tensorflow/compiler/plugin/poplar/driver/passes/embeddings_gradient_optimizer.cc

Lines changed: 34 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -29,9 +29,11 @@ limitations under the License.
2929
#include "tensorflow/compiler/plugin/poplar/driver/tools/matcher_predicates.h"
3030
#include "tensorflow/compiler/plugin/poplar/driver/tools/util.h"
3131
#include "tensorflow/compiler/plugin/poplar/kernels/ops.pb.h"
32+
3233
#include "tensorflow/compiler/xla/literal_util.h"
3334
#include "tensorflow/compiler/xla/service/call_graph.h"
3435
#include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
36+
#include "tensorflow/compiler/xla/service/hlo_creation_utils.h"
3537
#include "tensorflow/compiler/xla/shape_util.h"
3638

3739
namespace xla {
@@ -280,8 +282,6 @@ HloComputation* ReplaceResourceUpdateFunction(
280282
CreateMultiUpdateAdd(old_sink_shape,
281283
{zero_broadcast.get(), new_indices_reshape.get(),
282284
new_grads_arg.get(), const_1.get()},
283-
multi_update_add->GetIndexVectorDimension(),
284-
multi_update_add->GetUpdateSliceDimension(),
285285
multi_update_add->GetSerializationFactor());
286286

287287
for (auto old_sink_user : old_sink_arg->users()) {
@@ -355,8 +355,15 @@ StatusOr<HloComputation*> ReplaceAccumulatorCaller(
355355
ShapeUtil::MakeShape(update_index->shape().element_type(),
356356
{plan.mini_batch_size}),
357357
update_index.get(), {});
358-
auto update_index_broadcast_indices = HloInstruction::CreateBroadcast(
359-
ShapeUtil::MakeShape(update_index->shape().element_type(), {1}),
358+
359+
auto update_index_broadcast_grads_reshaped = HloInstruction::CreateReshape(
360+
ShapeUtil::MakeShape(update_index->shape().element_type(),
361+
/*dimensions=*/{plan.mini_batch_size, 1}),
362+
update_index_broadcast_grads.get(), {});
363+
364+
auto update_index_broadcast_indices_reshaped = HloInstruction::CreateReshape(
365+
ShapeUtil::MakeShape(update_index->shape().element_type(),
366+
/*dimensions=*/{1, 1}),
360367
update_index.get(), {});
361368

362369
auto const_int_1 = HloInstruction::CreateConstant(LiteralUtil::One(S32));
@@ -369,15 +376,19 @@ StatusOr<HloComputation*> ReplaceAccumulatorCaller(
369376
plan.accum_grads_shape,
370377
{pipeline_stage_accum_grads_param ? pipeline_stage_accum_grads_param.get()
371378
: accum_grads.get(),
372-
update_index_broadcast_grads.get(), grads, scale},
373-
1, 1, multi_update_add->GetSerializationFactor());
374-
auto indices_update = CreateMultiUpdateAdd(
375-
plan.accum_indices_shape,
376-
{pipeline_stage_accum_indices_param
377-
? pipeline_stage_accum_indices_param.get()
378-
: accum_indices.get(),
379-
update_index_broadcast_indices.get(), indices, const_int_1.get()},
380-
0, 0, multi_update_add->GetSerializationFactor());
379+
update_index_broadcast_grads_reshaped.get(), grads, scale},
380+
multi_update_add->GetSerializationFactor());
381+
382+
TF_ASSIGN_OR_RETURN(auto indices_transpose,
383+
MakeTransposeHlo(indices, {1, 0}));
384+
auto indices_update =
385+
CreateMultiUpdateAdd(plan.accum_indices_shape,
386+
{pipeline_stage_accum_indices_param
387+
? pipeline_stage_accum_indices_param.get()
388+
: accum_indices.get(),
389+
update_index_broadcast_indices_reshaped.get(),
390+
indices_transpose, const_int_1.get()},
391+
multi_update_add->GetSerializationFactor());
381392

382393
std::unique_ptr<HloInstruction> grads_update_gte, indices_update_gte;
383394
if (pipeline_stage) {
@@ -525,6 +536,11 @@ absl::optional<Candidate> CheckEmbeddingsCandidate(HloInstruction* inst) {
525536
"found.";
526537
return absl::nullopt;
527538
}
539+
// Updates need to be 2D.
540+
if (inst->shape().rank() != 2) {
541+
VLOG(2) << "The shape needs to be 2D.";
542+
return absl::nullopt;
543+
}
528544
return Candidate{Cast<HloGradientAccumulatorSink>(inst), {}};
529545
}
530546

@@ -627,6 +643,11 @@ absl::optional<PipelineCandidate> CheckPipelineEmbeddingsCandidate(
627643
return absl::nullopt;
628644
}
629645

646+
if (grad_create->shape().rank() != 2) {
647+
VLOG(2) << "The shape needs to be 2D.";
648+
return absl::nullopt;
649+
}
650+
630651
PipelineCandidate candidate{grad_create, grad_add, multi_update_add,
631652
grad_sink, grad_sink_gte, pipeline_stage,
632653
resource_update};

tensorflow/compiler/plugin/poplar/driver/passes/multi_update_apply.cc

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -242,8 +242,6 @@ StatusOr<HloInstruction*> CreateNewMultiUpdate(
242242
// Create a new multi update add instruction.
243243
HloInstruction* new_multi_update_add = comp->AddInstruction(
244244
CreateMultiUpdateAdd(operand->shape(), {operand, indices, updates, scale},
245-
multi_update->GetIndexVectorDimension(),
246-
multi_update->GetUpdateSliceDimension(),
247245
multi_update->GetSerializationFactor()));
248246

249247
if (shard) {

tensorflow/compiler/plugin/poplar/driver/passes/multi_update_canonicalize.cc

Lines changed: 0 additions & 180 deletions
This file was deleted.

tensorflow/compiler/plugin/poplar/driver/passes/multi_update_canonicalize.h

Lines changed: 0 additions & 39 deletions
This file was deleted.

0 commit comments

Comments
 (0)