Skip to content

Commit 8e24843

Browse files
samgizFrederik Mellbye
authored andcommitted
Restrict reduction preapply cases where the reduction is the RHS of a SUB or a DIVIDE
Summary: The elementwise preapply pass was being applied for reduce operations, when the reduce operation is the RHS operator of a subtract or divide operation, even though that is incorrect. This commit makes the pass not match such cases. REF T68013 Test Plan: CI passes, additional tests for this case. Reviewers: #tensorflow, #framework_ip_review_-_any_oss_or_third-party_code_use_has_been_approved, alfiee, gauthamg Reviewed By: #tensorflow, #framework_ip_review_-_any_oss_or_third-party_code_use_has_been_approved, gauthamg Maniphest Tasks: T68013 Differential Revision: https://phabricator.sourcevertex.net/D74620
1 parent e77e6d5 commit 8e24843

File tree

2 files changed

+32
-0
lines changed

2 files changed

+32
-0
lines changed

tensorflow/compiler/plugin/poplar/driver/passes/elementwise_preapply.cc

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -254,11 +254,17 @@ StatusOr<bool> HandleReduce(HloInstruction* inst, HloInstruction* elementwise) {
254254
if (elementwise->opcode() != HloOpcode::kAdd &&
255255
elementwise->opcode() != HloOpcode::kSubtract)
256256
return false;
257+
if (elementwise->opcode() == HloOpcode::kSubtract &&
258+
inst != elementwise->operand(0))
259+
return false;
257260
break;
258261
case HloOpcode::kMultiply:
259262
if (elementwise->opcode() != HloOpcode::kMultiply &&
260263
elementwise->opcode() != HloOpcode::kDivide)
261264
return false;
265+
if (elementwise->opcode() == HloOpcode::kDivide &&
266+
inst != elementwise->operand(0))
267+
return false;
262268
break;
263269
case HloOpcode::kMaximum:
264270
case HloOpcode::kMinimum:

tensorflow/compiler/plugin/poplar/tests/elementwise_preapply_test.cc

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -754,6 +754,30 @@ static std::string hlo_binary_float() {
754754
)";
755755
}
756756

757+
// Same as hlo_binary_float, but the arguments to output are swapped.
758+
// When $FUNC2 is subtract or divide, this should not match the pattern
759+
// even if hlo_binary_float() would.
760+
static std::string hlo_binary_float_swapped() {
761+
return R"(
762+
HloModule module
763+
764+
function_to_apply {
765+
p0 = f32[] parameter(0)
766+
p1 = f32[] parameter(1)
767+
ROOT output = f32[] $FUNC1(p0, p1)
768+
}
769+
770+
ENTRY f {
771+
reduce_param = f32[2, 3] constant({{1, 2, 3}, {4, 5, 6}})
772+
reduce_init = f32[] constant(5)
773+
reduce = f32[3] reduce(reduce_param, reduce_init), dimensions={0}, to_apply=function_to_apply
774+
c = f32[] constant(2)
775+
second_elementwise_arg = f32[3] broadcast(c), dimensions={}
776+
ROOT output = f32[3] $FUNC2(second_elementwise_arg, reduce)
777+
}
778+
)";
779+
}
780+
757781
static std::string hlo_binary_bool() {
758782
return absl::StrReplaceAll(
759783
hlo_binary_float(),
@@ -822,6 +846,8 @@ INSTANTIATE_TEST_SUITE_P(
822846
{"divide", "multiply", hlo_binary_float()},
823847
{"xor", "xor", hlo_binary_bool()},
824848
{"minimum", "copy", hlo_unary_float()},
849+
{"add", "subtract", hlo_binary_float_swapped()},
850+
{"multiply", "divide", hlo_binary_float_swapped()},
825851
// scalar result (no broadcast for constant)
826852
{"minimum", "maximum", hlo_scalar_result()},
827853
{"add", "multiply", hlo_scalar_result()},

0 commit comments

Comments
 (0)