Skip to content

Commit 083c085

Browse files
whoozlegeorgepaw
authored andcommitted
Simplify all-reduce(replication-normalise(all-reduce))
Summary: This expression is effectively just all-reduce(). This simplification significantly reduces RTS memory footprint. Fix T43594 Test Plan: CI, new test Reviewers: #tensorflow, #framework_ip_review_-_any_oss_or_third-party_code_use_has_been_approved, jakeh, georgep, hakons Reviewed By: #tensorflow, #framework_ip_review_-_any_oss_or_third-party_code_use_has_been_approved, georgep Maniphest Tasks: T43594 Differential Revision: https://phabricator.sourcevertex.net/D49231
1 parent f13daef commit 083c085

File tree

2 files changed

+60
-0
lines changed

2 files changed

+60
-0
lines changed

tensorflow/compiler/plugin/poplar/driver/passes/poplar_algebraic_simplifier.cc

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -171,6 +171,20 @@ StatusOr<HloInstruction*> PreserveFrontendAttributesIfNeeded(
171171
return new_inst;
172172
}
173173

174+
bool IsGlobalAllReduceWithSum(const HloInstruction* all_reduce) {
175+
if (all_reduce->opcode() != HloOpcode::kAllReduce ||
176+
!all_reduce->replica_groups().empty()) {
177+
return false;
178+
}
179+
auto& called_computations = all_reduce->called_computations();
180+
if (called_computations.size() != 1) {
181+
return false;
182+
}
183+
const HloComputation* comp = called_computations.front();
184+
return Match(comp->root_instruction(),
185+
m::Add(m::Parameter(), m::Parameter()));
186+
}
187+
174188
// AlgebraicSimplifierVisitor traverses the HLO computation and reduces certain
175189
// algebraic expressions to simplified forms. Note: This only supports
176190
// simplifications that simply look at the operands of an instruction. For the
@@ -183,6 +197,8 @@ class AlgebraicSimplifierVisitor : public DfsHloRewriteVisitor {
183197

184198
Status HandleAdd(HloInstruction* add) override;
185199

200+
Status HandleAllReduce(HloInstruction* all_reduce) override;
201+
186202
Status HandleAnd(HloInstruction* logical_and) override;
187203

188204
Status HandleBitcast(HloInstruction* bitcast) override;
@@ -3413,6 +3429,27 @@ Status AlgebraicSimplifierVisitor::HandleDynamicUpdateSlice(
34133429
return Status::OK();
34143430
}
34153431

3432+
Status AlgebraicSimplifierVisitor::HandleAllReduce(HloInstruction* all_reduce) {
3433+
/// Replace all-reduce(replication-normalise(all-reduce(arg))) with
3434+
/// all-reduce(arg)
3435+
if (all_reduce->operand_count() != 1 ||
3436+
!IsGlobalAllReduceWithSum(all_reduce)) {
3437+
return Status::OK();
3438+
}
3439+
HloInstruction* normalise = all_reduce->mutable_operand(0);
3440+
if (!pp::IsPoplarInstruction(PoplarOp::ReplicationNormalise, normalise) ||
3441+
normalise->operand_count() != 1) {
3442+
return Status::OK();
3443+
}
3444+
HloInstruction* top_all_reduce = normalise->mutable_operand(0);
3445+
if (top_all_reduce->opcode() == HloOpcode::kAllReduce &&
3446+
top_all_reduce->operand_count() == 1 &&
3447+
IsGlobalAllReduceWithSum(top_all_reduce)) {
3448+
return ReplaceInstruction(all_reduce, top_all_reduce);
3449+
}
3450+
return Status::OK();
3451+
}
3452+
34163453
Status AlgebraicSimplifierVisitor::HandleReduce(HloInstruction* hlo) {
34173454
HloReduceInstruction* reduce = Cast<HloReduceInstruction>(hlo);
34183455
bool multi_output_reduce = reduce->shape().IsTuple();

tensorflow/compiler/plugin/poplar/tests/poplar_algebraic_simplifier_test.cc

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5779,6 +5779,29 @@ ENTRY main {
57795779
EXPECT_FALSE(PoplarAlgebraicSimplifier().Run(m.get()).ValueOrDie());
57805780
}
57815781

5782+
TEST_F(PoplarAlgebraicSimplifierTest, SimplifyAllReduceNormaliseAllReduce) {
5783+
const char* kModuleStr = R"(
5784+
HloModule m
5785+
sum {
5786+
y = f32[] parameter(1)
5787+
x = f32[] parameter(0), control-predecessors={y}
5788+
ROOT add = f32[] add(x, y), backend_config="{\"isInplace\":true}"
5789+
}
5790+
5791+
ENTRY main {
5792+
arg0 = f32[1000] parameter(0)
5793+
all-reduce0 = f32[1000] all-reduce(arg0), to_apply=sum
5794+
normalise = f32[1000] custom-call(all-reduce0), custom_call_target="ReplicationNormalise"
5795+
ROOT all-reduce1 = f32[1000] all-reduce(normalise), to_apply=sum
5796+
}
5797+
)";
5798+
TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr));
5799+
ASSERT_TRUE(CustomOpReplacer().Run(m.get()).ValueOrDie());
5800+
ASSERT_TRUE(PoplarAlgebraicSimplifier().Run(m.get()).ValueOrDie());
5801+
EXPECT_THAT(m->entry_computation()->root_instruction(),
5802+
GmockMatch(m::AllReduce(m::Parameter(0))));
5803+
}
5804+
57825805
} // namespace
57835806
} // namespace poplarplugin
57845807
} // namespace xla

0 commit comments

Comments
 (0)