Commit fccb3bf

Improved error message: verify that "distributed_batch_norm_replica_group_size" divides replication factor (TF2)
Summary: D73331 cherry-picked for TF2

Reviewers: #tensorflow, #framework_ip_review_-_any_oss_or_third-party_code_use_has_been_approved, gauthamg

Reviewed By: #tensorflow, #framework_ip_review_-_any_oss_or_third-party_code_use_has_been_approved, gauthamg

Maniphest Tasks: T57992

Differential Revision: https://phabricator.sourcevertex.net/D73609

1 parent 748c31f commit fccb3bf

1 file changed: +10 -0 lines changed

tensorflow/compiler/plugin/poplar/driver/poplar_compiler.cc

Lines changed: 10 additions & 0 deletions
@@ -1624,6 +1624,8 @@ StatusOr<std::unique_ptr<PoplarExecutableCore>> CompileEngine(
 
   const auto num_local_ipus = poplar_executor->GetNumIpusInLocalProcess(target);
   const auto local_replication_factor = num_local_ipus / num_shards;
+  const auto replica_group_size =
+      poplar_executor->ExperimentalDistributedBatchNormReplicaGroupSize();
 
   if (num_local_ipus % num_shards) {
     return xla::InternalErrorStrCat(
@@ -1632,6 +1634,14 @@ StatusOr<std::unique_ptr<PoplarExecutableCore>> CompileEngine(
         " The number of shards needs to divide the number of local IPUs.");
   }
 
+  if (replica_group_size && replication_factor % replica_group_size) {
+    return xla::InternalErrorStrCat(
+        "The number of replicas (", replication_factor,
+        ") must be divisible by",
+        " distributed_batch_norm_replica_group_size (", replica_group_size,
+        ").");
+  }
+
   // Currently we only support performing replica partitioning across the local
   // replicas in each process, as this allows access to all the parts of a
   // partitioned remote buffer locally. This means that copying to/from all the
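
For readers who want to try the new check outside the compiler, here is a minimal standalone sketch of the same divisibility rule. The helper name `CheckReplicaGroupSize`, the `std::int64_t` parameter types, and the `std::optional<std::string>` return value are assumptions made for illustration; the commit itself returns the error through `xla::InternalErrorStrCat` inside `CompileEngine`. As in the diff, a group size of 0 skips the check, which also avoids a modulo by zero.

```cpp
#include <cstdint>
#include <optional>
#include <string>

// Hypothetical standalone helper, not part of the Poplar plugin.
// Returns an error message when the check fails, std::nullopt otherwise.
std::optional<std::string> CheckReplicaGroupSize(
    std::int64_t replication_factor, std::int64_t replica_group_size) {
  // A group size of 0 skips the check, mirroring the
  // `replica_group_size && ...` guard in the diff.
  if (replica_group_size == 0) {
    return std::nullopt;
  }
  // The replication factor must be a whole multiple of the group size.
  if (replication_factor % replica_group_size != 0) {
    return "The number of replicas (" + std::to_string(replication_factor) +
           ") must be divisible by"
           " distributed_batch_norm_replica_group_size (" +
           std::to_string(replica_group_size) + ").";
  }
  return std::nullopt;
}
```

Under these assumptions, `CheckReplicaGroupSize(8, 4)` and `CheckReplicaGroupSize(8, 0)` pass, while `CheckReplicaGroupSize(8, 3)` yields the error text, since 8 is not a multiple of 3.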
