@@ -171,6 +171,20 @@ StatusOr<HloInstruction*> PreserveFrontendAttributesIfNeeded(
171171 return new_inst;
172172}
173173
174+ bool IsGlobalAllReduceWithSum (const HloInstruction* all_reduce) {
175+ if (all_reduce->opcode () != HloOpcode::kAllReduce ||
176+ !all_reduce->replica_groups ().empty ()) {
177+ return false ;
178+ }
179+ auto & called_computations = all_reduce->called_computations ();
180+ if (called_computations.size () != 1 ) {
181+ return false ;
182+ }
183+ const HloComputation* comp = called_computations.front ();
184+ return Match (comp->root_instruction (),
185+ m::Add (m::Parameter (), m::Parameter ()));
186+ }
187+
174188// AlgebraicSimplifierVisitor traverses the HLO computation and reduces certain
175189// algebraic expressions to simplified forms. Note: This only supports
176190// simplifications that simply look at the operands of an instruction. For the
@@ -183,6 +197,8 @@ class AlgebraicSimplifierVisitor : public DfsHloRewriteVisitor {
183197
184198 Status HandleAdd (HloInstruction* add) override ;
185199
200+ Status HandleAllReduce (HloInstruction* all_reduce) override ;
201+
186202 Status HandleAnd (HloInstruction* logical_and) override ;
187203
188204 Status HandleBitcast (HloInstruction* bitcast) override ;
@@ -3413,6 +3429,27 @@ Status AlgebraicSimplifierVisitor::HandleDynamicUpdateSlice(
34133429 return Status::OK ();
34143430}
34153431
3432+ Status AlgebraicSimplifierVisitor::HandleAllReduce (HloInstruction* all_reduce) {
3433+ // / Replace all-reduce(replication-normalise(all-reduce(arg))) with
3434+ // / all-reduce(arg)
3435+ if (all_reduce->operand_count () != 1 ||
3436+ !IsGlobalAllReduceWithSum (all_reduce)) {
3437+ return Status::OK ();
3438+ }
3439+ HloInstruction* normalise = all_reduce->mutable_operand (0 );
3440+ if (!pp::IsPoplarInstruction (PoplarOp::ReplicationNormalise, normalise) ||
3441+ normalise->operand_count () != 1 ) {
3442+ return Status::OK ();
3443+ }
3444+ HloInstruction* top_all_reduce = normalise->mutable_operand (0 );
3445+ if (top_all_reduce->opcode () == HloOpcode::kAllReduce &&
3446+ top_all_reduce->operand_count () == 1 &&
3447+ IsGlobalAllReduceWithSum (top_all_reduce)) {
3448+ return ReplaceInstruction (all_reduce, top_all_reduce);
3449+ }
3450+ return Status::OK ();
3451+ }
3452+
34163453Status AlgebraicSimplifierVisitor::HandleReduce (HloInstruction* hlo) {
34173454 HloReduceInstruction* reduce = Cast<HloReduceInstruction>(hlo);
34183455 bool multi_output_reduce = reduce->shape ().IsTuple ();
0 commit comments