Skip to content

Commit de7e74f

Browse files
committed
Fix selection of inputs to clusters
Summary: Ref T40083
Reviewers: vladimirm, hakons, #tensorflow, #framework_ip_review_-_any_oss_or_third-party_code_use_has_been_approved
Reviewed By: vladimirm, hakons, #tensorflow, #framework_ip_review_-_any_oss_or_third-party_code_use_has_been_approved
Maniphest Tasks: T40083
Differential Revision: https://phabricator.sourcevertex.net/D53213
1 parent 07ff34e commit de7e74f

File tree

2 files changed

+8
-26
lines changed

2 files changed

+8
-26
lines changed

tensorflow/compiler/plugin/poplar/driver/tools/elementwise_cluster.cc

Lines changed: 6 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -130,19 +130,19 @@ ElementwiseClusterValidator::Inputs ElementwiseClusterValidator::GetValidInputs(
130130
Inputs valid_inputs;
131131
for (const HloInstruction* inst : comp->MakeInstructionPostOrder()) {
132132
bool valid_input = false;
133-
if (IsParameter(inst) && parameter_filter(inst->parameter_number())) {
133+
if (IsParameter(inst)) {
134+
valid_input = parameter_filter(inst->parameter_number());
135+
} else if (IsGlobalAllReduce(inst)) {
134136
valid_input = true;
137+
} else if (IsBroadcast(inst)) {
138+
valid_input =
139+
IsScalar(inst->operand(0)) && valid_inputs.contains(inst->operand(0));
135140
} else if (!IsPoplarInstruction(PoplarOp::ExecutionCounter, inst) &&
136141
!inst->HasSideEffect()) {
137142
valid_input = absl::c_all_of(
138143
inst->operands(), [&valid_inputs](const HloInstruction* operand) {
139144
return valid_inputs.contains(operand);
140145
});
141-
} else if (IsBroadcast(inst)) {
142-
valid_input =
143-
IsScalar(inst->operand(0)) && valid_inputs.contains(inst->operand(0));
144-
} else if (IsGlobalAllReduce(inst)) {
145-
valid_input = true;
146146
}
147147

148148
if (valid_input) {

tensorflow/compiler/plugin/poplar/tests/replicated_resource_update_elementwise_clustering_test.cc

Lines changed: 2 additions & 20 deletions
Original file line number | Diff line number | Diff line change
@@ -1666,26 +1666,8 @@ TEST_F(TestPartitionReplicationFactor, TestNonGlobalAllReduce) {
16661666
ElementwiseCluster::GetElementwiseClusterableComputations(module.get());
16671667
TF_ASSERT_OK_AND_ASSIGN(auto clusters,
16681668
pass.GetClustersIn(loop, elementwise_comps));
1669-
ASSERT_THAT(clusters.size(), 1);
1670-
1671-
const int64 shard_size = 128 / partition_replication_factor;
1672-
auto& cluster = *std::begin(clusters);
1673-
EXPECT_THAT(cluster.GetClusterSize(), 128);
1674-
EXPECT_THAT(cluster.GetAlignedClusterSize(), 128);
1675-
EXPECT_THAT(cluster.GetShardSize(), shard_size);
1676-
1677-
// Convert the cluster.
1678-
TF_ASSERT_OK(pass.OutlineCluster(cluster).status());
1679-
TF_ASSERT_OK_AND_ASSIGN(bool eliminated, HloDCE().Run(module.get()));
1680-
1681-
EXPECT_THAT(arg0->user_count(), 1);
1682-
HloInstruction* all_reduce = GetNextUser(arg0).ValueOrDie();
1683-
1684-
// The non-global all-reduce should be left alone.
1685-
EXPECT_THAT(all_reduce->opcode(), HloOpcode::kAllReduce);
1686-
EXPECT_THAT(all_reduce->replica_groups().size(), 1);
1687-
EXPECT_THAT(all_reduce->replica_groups()[0].replica_ids(),
1688-
::testing::ElementsAre(0));
1669+
// No clusters because the all-reduce is not global.
1670+
ASSERT_THAT(clusters.size(), 0);
16891671
}
16901672

16911673
TEST_F(TestPartitionReplicationFactor, IgnoreImplicit2ScalarBroadcast) {

0 commit comments

Comments
 (0)