Skip to content

Commit 5b27a0a

Browse files
Han Zhao authored and georgepaw committed
Simplify concatenate of same slices
Summary: Replace concatenate of same slices with broadcast + reshape.
Test Plan: CI.
Reviewers: #tensorflow, #framework_ip_review_-_any_oss_or_third-party_code_use_has_been_approved, vladimirm, jackh, jakeh.
Reviewed By: #tensorflow, #framework_ip_review_-_any_oss_or_third-party_code_use_has_been_approved, vladimirm, jackh, jakeh.
Subscribers: jakeh.
Maniphest Tasks: T57213.
Differential Revision: https://phabricator.sourcevertex.net/D62112
1 parent c7086ef commit 5b27a0a

File tree

3 files changed

+271
-0
lines changed

3 files changed

+271
-0
lines changed

tensorflow/compiler/plugin/poplar/driver/passes/poplar_algebraic_simplifier.cc

Lines changed: 121 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -398,6 +398,120 @@ Status AlgebraicSimplifierVisitor::HandleCopy(HloInstruction* copy) {
398398
return Status::OK();
399399
}
400400

401+
// Concatenate same slices can be replaced by broadcast + reshape
402+
// +-------------+
403+
// | Operand |
404+
// +-------------+
405+
// | |
406+
// | |
407+
// v v
408+
// +----------+ +----------+
409+
// | Slice | | Slice |
410+
// +----------+ +----------+
411+
// | | | | | | | |
412+
// | | | | | | | |
413+
// v v v v v v v v
414+
// +---------------------------+
415+
// | Concatenate |
416+
// +---------------------------+
417+
bool AlgebraicSimplifierVisitor::TrySimplifyConcatenateOfSameSlices(
418+
HloInstruction* concatenate) {
419+
absl::Span<HloInstruction* const> operands(concatenate->operands());
420+
int64 concatenate_dimension = concatenate->concatenate_dimension();
421+
// Make sure all the operands are slice, and all the slices are from the
422+
// same op
423+
for (int64 i = 0; i < static_cast<int64>(operands.size()); ++i) {
424+
// Make sure all the operands are the slice
425+
if (operands[i]->opcode() != HloOpcode::kSlice ||
426+
!pp::algebraic_simplifier::util::IsUnstridedSlice(operands[i])) {
427+
return false;
428+
}
429+
430+
// Make sure all the slices are from the same op
431+
if (operands[i]->mutable_operand(0) != operands[0]->mutable_operand(0)) {
432+
return false;
433+
}
434+
}
435+
436+
// Record the same slice
437+
bool has_same_slices = false;
438+
std::vector<std::pair<HloInstruction*, int64>> slice_cnt;
439+
for (auto current = operands.begin(); current != operands.end();) {
440+
// Find the end of the matching slices.
441+
auto next = std::partition_point(
442+
current, operands.end(), [&current](HloInstruction* operand) -> bool {
443+
return (((*current)->slice_starts() == operand->slice_starts()) &&
444+
((*current)->slice_limits() == operand->slice_limits()));
445+
});
446+
447+
// Recording the number of matching slices and whether we've seen a group
448+
// larger than 1.
449+
slice_cnt.push_back({*current, std::distance(current, next)});
450+
has_same_slices |= (slice_cnt.back().second > 1);
451+
452+
// Move onto the next potential slice group.
453+
current = next;
454+
}
455+
if (!has_same_slices) {
456+
return false;
457+
}
458+
459+
// Replace same slices with broadcast + reshape
460+
std::vector<HloInstruction*> new_operands;
461+
for (const auto& pair : slice_cnt) {
462+
if (pair.second == 1) {
463+
new_operands.push_back(pair.first);
464+
} else {
465+
// Create broadcast
466+
// Create the broadcast output shape.
467+
HloInstruction* bcast_operand = pair.first;
468+
auto shape_dims = bcast_operand->shape().dimensions();
469+
std::vector<int64> new_shape_dims(shape_dims.begin(), shape_dims.end());
470+
new_shape_dims.insert(new_shape_dims.begin() + concatenate_dimension + 1,
471+
pair.second);
472+
Shape bcast_shape = ShapeUtil::MakeShape(
473+
bcast_operand->shape().element_type(), new_shape_dims);
474+
475+
// Create the dimensions for broadcast
476+
std::vector<int64> bcast_dims(bcast_operand->shape().dimensions().size());
477+
std::iota(bcast_dims.begin(),
478+
bcast_dims.begin() + concatenate_dimension + 1, 0);
479+
std::iota(bcast_dims.begin() + concatenate_dimension + 1,
480+
bcast_dims.end(), concatenate_dimension + 2);
481+
482+
HloInstruction* broadcast =
483+
computation_->AddInstruction(HloInstruction::CreateBroadcast(
484+
bcast_shape, bcast_operand, bcast_dims));
485+
486+
// Create reshape
487+
auto reshape_dims = bcast_operand->shape().dimensions();
488+
std::vector<int64> new_reshape_dims(reshape_dims.begin(),
489+
reshape_dims.end());
490+
new_reshape_dims[concatenate_dimension] *= pair.second;
491+
Shape reshape_shape = ShapeUtil::MakeShape(
492+
bcast_operand->shape().element_type(), new_reshape_dims);
493+
494+
HloInstruction* reshape =
495+
computation_->AddInstruction(HloInstruction::CreateReshape(
496+
reshape_shape, broadcast, {concatenate_dimension}));
497+
new_operands.push_back(reshape);
498+
}
499+
}
500+
501+
if (new_operands.size() == 1) {
502+
// No need to add the concatenate if the inputs to the concatenate are all
503+
// from the same slice
504+
ReplaceInstructionIfSameShape(concatenate, new_operands[0]);
505+
} else {
506+
// Replace with new concatenate op
507+
HloInstruction* new_concatenate = computation_->AddInstruction(
508+
concatenate->CloneWithNewOperands(concatenate->shape(), new_operands));
509+
ReplaceInstructionIfSameShape(concatenate, new_concatenate);
510+
}
511+
512+
return true;
513+
}
514+
401515
Status AlgebraicSimplifierVisitor::HandleConcatenate(
402516
HloInstruction* concatenate) {
403517
absl::Span<HloInstruction* const> operands(concatenate->operands());
@@ -454,6 +568,13 @@ Status AlgebraicSimplifierVisitor::HandleConcatenate(
454568
return Status::OK();
455569
}
456570

571+
// Concatenate same slices can be replaced by broadcast + reshape
572+
VLOG(10)
573+
<< "Trying to replace a concatenate of matching slices with a broadcast";
574+
if (TrySimplifyConcatenateOfSameSlices(concatenate)) {
575+
return Status::OK();
576+
}
577+
457578
// Check if we can merge "adjacent" slice operands which take slices from the
458579
// same other op. For simplicity we only merge unstrided slices.
459580
int64 concatenate_dimension = concatenate->concatenate_dimension();

tensorflow/compiler/plugin/poplar/driver/passes/poplar_algebraic_simplifier.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -226,6 +226,9 @@ class AlgebraicSimplifierVisitor : public DfsHloRewriteVisitor {
226226
// Useful when we want to use the same visitor over multiple computations.
227227
void ResetState(HloComputation* computation);
228228

229+
// Tries to replace a concatenate of identical slices with a broadcast +
// reshape. Returns true if the concatenate was replaced.
230+
bool TrySimplifyConcatenateOfSameSlices(HloInstruction* concatenate);
231+
229232
// Current HloComputation instance the AlgebraicSimplifierVisitor is
230233
// traversing.
231234
HloComputation* computation_;

tensorflow/compiler/plugin/poplar/tests/poplar_algebraic_simplifier_test.cc

Lines changed: 147 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2121,6 +2121,153 @@ TEST_F(PoplarAlgebraicSimplifierTest, SimplifyConcatenateOfSlices) {
21212121
EXPECT_EQ(computation->root_instruction()->operand(3)->slice_starts(1), 40);
21222122
}
21232123

2124+
// Concatenating columns {0, 1, 1, 2, 2, 3} of a [3, 4] parameter: the two
// repeated columns should each become reshape(broadcast(slice)), while the
// unique first and last columns stay plain slices.
TEST_F(PoplarAlgebraicSimplifierTest, SimplifyConcatenateOfSameSlices_1_2_2_1) {
  auto module = CreateNewVerifiedModule();
  const Shape param_shape = ShapeUtil::MakeShape(F32, {3, 4});
  const Shape result_shape = ShapeUtil::MakeShape(F32, {3, 6});

  HloComputation::Builder builder(TestName());
  HloInstruction* param = builder.AddInstruction(
      HloInstruction::CreateParameter(0, param_shape, "param0"));

  // Adds the unstrided [3, 1] slice selecting column `col` of the parameter.
  auto add_column = [&](int64 col) -> HloInstruction* {
    return builder.AddInstruction(HloInstruction::CreateSlice(
        ShapeUtil::MakeShape(F32, {3, 1}), param, /*start_indices=*/{0, col},
        /*limit_indices=*/{3, col + 1}, /*strides=*/{1, 1}));
  };
  HloInstruction* col0 = add_column(0);
  HloInstruction* col1 = add_column(1);
  HloInstruction* col2 = add_column(2);
  HloInstruction* col3 = add_column(3);

  builder.AddInstruction(HloInstruction::CreateConcatenate(
      result_shape, {col0, col1, col1, col2, col2, col3}, 1));
  auto computation = module->AddEntryComputation(builder.Build());

  PoplarAlgebraicSimplifier simplifier;
  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());

  auto root = computation->root_instruction();
  EXPECT_THAT(root, GmockMatch(m::Concatenate(
                        m::Slice(m::Parameter(0)),
                        m::Reshape(m::Broadcast(m::Slice(m::Parameter(0)))),
                        m::Reshape(m::Broadcast(m::Slice(m::Parameter(0)))),
                        m::Slice(m::Parameter(0)))));

  // The broadcast under the first rewritten operand carries the repeat count.
  const HloInstruction* first_broadcast = root->operand(1)->operand(0);
  EXPECT_TRUE(ShapeUtil::Equal(first_broadcast->shape(),
                               ShapeUtil::MakeShape(F32, {3, 1, 2})));

  // Each operand must still read the column it originally selected.
  const HloInstruction* out0 = root->operand(0);
  const HloInstruction* out1 = root->operand(1)->operand(0)->operand(0);
  const HloInstruction* out2 = root->operand(2)->operand(0)->operand(0);
  const HloInstruction* out3 = root->operand(3);
  EXPECT_THAT(out0->slice_starts(), ElementsAre(0, 0));
  EXPECT_THAT(out0->slice_limits(), ElementsAre(3, 1));
  EXPECT_THAT(out1->slice_starts(), ElementsAre(0, 1));
  EXPECT_THAT(out1->slice_limits(), ElementsAre(3, 2));
  EXPECT_THAT(out2->slice_starts(), ElementsAre(0, 2));
  EXPECT_THAT(out2->slice_limits(), ElementsAre(3, 3));
  EXPECT_THAT(out3->slice_starts(), ElementsAre(0, 3));
  EXPECT_THAT(out3->slice_limits(), ElementsAre(3, 4));
}
2178+
2179+
// Concatenating each of the four [1, 1] columns of a [1, 4] parameter five
// times in a row: every run should become reshape(broadcast(slice)) and the
// concatenate should end up with four operands.
TEST_F(PoplarAlgebraicSimplifierTest, SimplifyConcatenateOfSameSlices_5_5_5_5) {
  auto module = CreateNewVerifiedModule();
  const Shape param_shape = ShapeUtil::MakeShape(F32, {1, 4});
  const Shape result_shape = ShapeUtil::MakeShape(F32, {1, 20});

  HloComputation::Builder builder(TestName());
  HloInstruction* param = builder.AddInstruction(
      HloInstruction::CreateParameter(0, param_shape, "param0"));

  // Create the four single-element slices first, in column order.
  std::vector<HloInstruction*> columns;
  for (int64 col = 0; col < 4; ++col) {
    columns.push_back(builder.AddInstruction(HloInstruction::CreateSlice(
        ShapeUtil::MakeShape(F32, {1, 1}), param, /*start_indices=*/{0, col},
        /*limit_indices=*/{1, col + 1}, /*strides=*/{1, 1})));
  }

  // Feed every column to the concatenate five times in a row.
  std::vector<HloInstruction*> concat_inputs;
  for (HloInstruction* column : columns) {
    concat_inputs.insert(concat_inputs.end(), 5, column);
  }
  builder.AddInstruction(
      HloInstruction::CreateConcatenate(result_shape, concat_inputs, 1));
  auto computation = module->AddEntryComputation(builder.Build());

  PoplarAlgebraicSimplifier simplifier;
  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());

  auto root = computation->root_instruction();
  EXPECT_THAT(root, GmockMatch(m::Concatenate(
                        m::Reshape(m::Broadcast(m::Slice(m::Parameter(0)))),
                        m::Reshape(m::Broadcast(m::Slice(m::Parameter(0)))),
                        m::Reshape(m::Broadcast(m::Slice(m::Parameter(0)))),
                        m::Reshape(m::Broadcast(m::Slice(m::Parameter(0)))))));
  EXPECT_TRUE(ShapeUtil::Equal(root->operand(0)->operand(0)->shape(),
                               ShapeUtil::MakeShape(F32, {1, 1, 5})));

  // Each rewritten operand must still read its original column.
  const HloInstruction* out0 = root->operand(0)->operand(0)->operand(0);
  const HloInstruction* out1 = root->operand(1)->operand(0)->operand(0);
  const HloInstruction* out2 = root->operand(2)->operand(0)->operand(0);
  const HloInstruction* out3 = root->operand(3)->operand(0)->operand(0);
  EXPECT_THAT(out0->slice_starts(), ElementsAre(0, 0));
  EXPECT_THAT(out0->slice_limits(), ElementsAre(1, 1));
  EXPECT_THAT(out1->slice_starts(), ElementsAre(0, 1));
  EXPECT_THAT(out1->slice_limits(), ElementsAre(1, 2));
  EXPECT_THAT(out2->slice_starts(), ElementsAre(0, 2));
  EXPECT_THAT(out2->slice_limits(), ElementsAre(1, 3));
  EXPECT_THAT(out3->slice_starts(), ElementsAre(0, 3));
  EXPECT_THAT(out3->slice_limits(), ElementsAre(1, 4));
}
2238+
2239+
// Concatenating the same [1, 1] slice five times: the concatenate should
// disappear entirely, leaving reshape(broadcast(slice)) as the root.
TEST_F(PoplarAlgebraicSimplifierTest, SimplifyConcatenateOfSameSlices_1x5) {
  auto module = CreateNewVerifiedModule();
  const Shape param_shape = ShapeUtil::MakeShape(F32, {1, 4});
  const Shape result_shape = ShapeUtil::MakeShape(F32, {1, 5});

  HloComputation::Builder builder(TestName());
  HloInstruction* param = builder.AddInstruction(
      HloInstruction::CreateParameter(0, param_shape, "param0"));

  HloInstruction* first_element =
      builder.AddInstruction(HloInstruction::CreateSlice(
          ShapeUtil::MakeShape(F32, {1, 1}), param, /*start_indices=*/{0, 0},
          /*limit_indices=*/{1, 1}, /*strides=*/{1, 1}));

  builder.AddInstruction(HloInstruction::CreateConcatenate(
      result_shape,
      {first_element, first_element, first_element, first_element,
       first_element},
      1));
  auto computation = module->AddEntryComputation(builder.Build());

  PoplarAlgebraicSimplifier simplifier;
  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());

  auto root = computation->root_instruction();

  // With a single run covering every input, no concatenate should remain.
  EXPECT_THAT(root,
              GmockMatch(m::Reshape(m::Broadcast(m::Slice(m::Parameter(0))))));

  const HloInstruction* broadcast = root->operand(0);
  const HloInstruction* slice = broadcast->operand(0);
  EXPECT_TRUE(ShapeUtil::Equal(broadcast->shape(),
                               ShapeUtil::MakeShape(F32, {1, 1, 5})));
  EXPECT_THAT(slice->slice_starts(), ElementsAre(0, 0));
  EXPECT_THAT(slice->slice_limits(), ElementsAre(1, 1));
}
2270+
21242271
// Test transforming reshapes and transposes of rng.
21252272
TEST_F(PoplarAlgebraicSimplifierTest, ReshapeOfTransposeOfRngToRng) {
21262273
auto m = CreateNewVerifiedModule();

0 commit comments

Comments
 (0)