Skip to content

Commit 0f62148

Browse files
committed
Improve error message for unsupported RTS config
Summary: This reverts commit df5bdd3. Test Plan: revert-hammer Reviewers: Subscribers:
1 parent e9b0081 commit 0f62148

File tree

4 files changed

+137
-9
lines changed

4 files changed

+137
-9
lines changed

tensorflow/compiler/plugin/poplar/driver/passes/replicated_resource_update_elementwise_clustering.cc

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -280,6 +280,19 @@ ReplicatedResourceUpdateElementwiseClustering::AddClusterInput(
280280
// replica 3: 0 1 1 1 1 2 3 4
281281
// ------- ------- -------
282282

283+
if (ipu_link_domain_replication_factor_ < global_replication_factor_ &&
284+
partition_replication_factor_ < ipu_link_domain_replication_factor_) {
285+
// TODO(T45704): GCL does not support this and will report an error, but
286+
// the message might be hard to understand for the user. So we catch it
287+
// here to attempt to report a more understandable error message.
288+
return Unimplemented(
289+
"Replicated partitioning is not supported when there are multiple "
290+
"instances per IPU-link domain. The number of local replicas per "
291+
"partitioning domain (%u) is less than the number of replicas "
292+
"per IPU-link domain (%u).",
293+
partition_replication_factor_, ipu_link_domain_replication_factor_);
294+
}
295+
283296
CHECK_NE(partition_replication_factor_, 0);
284297
CHECK_EQ(global_replication_factor_ % partition_replication_factor_, 0);
285298
const uint64 orthogonal_group_size =
@@ -549,7 +562,7 @@ ReplicatedResourceUpdateElementwiseClustering::UpdateClusterBackendConfig(
549562
cluster, backend_config));
550563
auto* call_config = backend_config.mutable_call_config();
551564
auto* function_config = call_config->mutable_function_config();
552-
// Setting parititoned_elementwise_cluster attribute indicates that we will
565+
// Setting partitioned_elementwise_cluster attribute indicates that we will
553566
// process those clusters differently:
554567
// - Remote buffer outlining pass will outline load/stores regardless if it's
555568
// unique or not.

tensorflow/compiler/plugin/poplar/driver/passes/replicated_resource_update_elementwise_clustering.h

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -33,15 +33,19 @@ class ReplicatedResourceUpdateElementwiseClustering final
3333
public:
3434
ReplicatedResourceUpdateElementwiseClustering(
3535
CompilerAnnotations& annotations, uint32 partition_replication_factor,
36-
uint32 global_replication_factor)
36+
uint32 global_replication_factor,
37+
uint32 ipu_link_domain_replication_factor)
3738
: annotations_(annotations),
3839
partition_replication_factor_(partition_replication_factor),
39-
global_replication_factor_(global_replication_factor) {}
40+
global_replication_factor_(global_replication_factor),
41+
ipu_link_domain_replication_factor_(
42+
ipu_link_domain_replication_factor) {}
4043

4144
explicit ReplicatedResourceUpdateElementwiseClustering(
4245
CompilerAnnotations& annotations, uint32 replication_factor)
4346
: ReplicatedResourceUpdateElementwiseClustering(
44-
annotations, replication_factor, replication_factor) {}
47+
annotations, replication_factor, replication_factor,
48+
replication_factor) {}
4549

4650
absl::string_view name() const override {
4751
return "replicated-resource-update-elementwise-clustering";
@@ -86,6 +90,7 @@ class ReplicatedResourceUpdateElementwiseClustering final
8690
CompilerAnnotations& annotations_;
8791
uint32 partition_replication_factor_;
8892
uint32 global_replication_factor_;
93+
uint32 ipu_link_domain_replication_factor_;
8994
};
9095

9196
} // namespace poplarplugin

tensorflow/compiler/plugin/poplar/driver/poplar_compiler.cc

Lines changed: 21 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1228,17 +1228,35 @@ StatusOr<std::unique_ptr<PoplarExecutableCore>> CompileEngine(
12281228
" The number of shards needs to divide the number of local IPUs.");
12291229
}
12301230

1231-
CHECK_LE(local_replication_factor, replication_factor);
1232-
12331231
// Currently we only support performing replica partitioning across the local
12341232
// replicas in each process, as this allows access to all the parts of a
12351233
// partitioned remote buffer locally. This means that copying to/from all the
12361234
// parts of the partitioned remote buffer can be done without any additional
12371235
// inter-process collective communication.
12381236
const auto partition_replication_factor = local_replication_factor;
12391237

1238+
// The IPU-link domain size is the number of IPUs per IPU-link domain.
1239+
// A CPU target cannot be trusted to report a sensible value.
1240+
const auto num_ipu_link_domain_ipus =
1241+
target.getTargetType() == poplar::TargetType::CPU
1242+
? num_ipus
1243+
: std::min(num_ipus, target.getIpuLinkDomainSize());
1244+
CHECK_GE(num_ipu_link_domain_ipus, num_shards);
1245+
CHECK_EQ(num_ipu_link_domain_ipus % num_shards, 0);
1246+
const auto ipu_link_domain_replication_factor =
1247+
num_ipu_link_domain_ipus / num_shards;
1248+
1249+
// Replication factor invariant: local <= ipu_link_domain <= global.
1250+
CHECK_LE(local_replication_factor, ipu_link_domain_replication_factor);
1251+
CHECK_LE(ipu_link_domain_replication_factor, replication_factor);
1252+
1253+
CHECK_EQ(replication_factor % ipu_link_domain_replication_factor, 0);
1254+
CHECK_EQ(ipu_link_domain_replication_factor % local_replication_factor, 0);
1255+
12401256
VLOG(1) << "Local replication factor " << local_replication_factor
12411257
<< ", global replication factor " << replication_factor
1258+
<< ", IPU-link domain replication factor "
1259+
<< ipu_link_domain_replication_factor
12421260
<< ", partition replication factor " << partition_replication_factor;
12431261

12441262
auto config = module->config();
@@ -1459,7 +1477,7 @@ StatusOr<std::unique_ptr<PoplarExecutableCore>> CompileEngine(
14591477
pipeline.AddPass<PipelineFIFOInserter>(resources.remote_memory_supported);
14601478
pipeline.AddPass<ReplicatedResourceUpdateElementwiseClustering>(
14611479
resources.annotations, resources.partition_replication_factor,
1462-
resources.replication_factor);
1480+
resources.replication_factor, ipu_link_domain_replication_factor);
14631481
{
14641482
auto inline_fusion = [](const HloInstruction* inst) {
14651483
return IsReplicatedParameterLoadFusion(inst) ||

tensorflow/compiler/plugin/poplar/tests/replicated_resource_update_elementwise_clustering_test.cc

Lines changed: 94 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1436,7 +1436,8 @@ TEST_F(TestPartitionReplicationFactor, TestCollectiveGroups) {
14361436
EXPECT_TRUE(offloaded);
14371437

14381438
ReplicatedResourceUpdateElementwiseClustering pass(
1439-
annotations, partition_replication_factor, global_replication_factor);
1439+
annotations, partition_replication_factor, global_replication_factor,
1440+
global_replication_factor);
14401441
auto elementwise_comps =
14411442
ElementwiseCluster::GetElementwiseClusterableComputations(module.get());
14421443
TF_ASSERT_OK_AND_ASSIGN(auto clusters,
@@ -1497,6 +1498,96 @@ TEST_F(TestPartitionReplicationFactor, TestCollectiveGroups) {
14971498
partition_replication_factor));
14981499
}
14991500

1501+
TEST_F(TestPartitionReplicationFactor, TestUnsupportedPartitioning) {
1502+
const std::string hlo = R"(
1503+
HloModule main
1504+
1505+
sum {
1506+
y = f16[] parameter(1)
1507+
x = f16[] parameter(0), control-predecessors={y}
1508+
ROOT add = f16[] add(x, y), backend_config="{\"isInplace\":true}"
1509+
}
1510+
1511+
resource_update {
1512+
arg0 = f16[128] parameter(0)
1513+
arg1 = f16[128] parameter(1)
1514+
arg2 = f16[128] parameter(2)
1515+
1516+
arg0_r = f16[128] all-reduce(arg0), to_apply=sum
1517+
arg2_new = f16[128] add(arg0_r, arg2)
1518+
arg1_new = f16[128] add(arg1, arg2_new)
1519+
1520+
ROOT t = (f16[128],f16[128]) tuple(arg1_new, arg2_new)
1521+
counter_0 = s32[] constant(4)
1522+
gac = () custom-call(s32[] counter_0), custom_call_target="GradientAccumulationCount"
1523+
}
1524+
1525+
loop {
1526+
after-all = token[] after-all()
1527+
infeed = (f16[128], token[]) infeed(after-all), infeed_config="140121807314576"
1528+
input = f16[128] get-tuple-element(infeed), index=0
1529+
1530+
arg0 = f16[128] parameter(0)
1531+
arg1 = f16[128] parameter(1)
1532+
1533+
add.1 = f16[128] add(input, arg0)
1534+
call = (f16[128],f16[128]) call(add.1, arg0, arg1), to_apply=resource_update, frontend_attributes={CALL_CONFIG_TYPE="ResourceUpdate"}, backend_config="{\"callConfig\":{\"type\":\"ResourceUpdate\",\"resourceUpdateConfig\":{\"offloadVariables\":\"THREESTATE_ON\", \"partitionOffloadedVariables\":\"THREESTATE_ON\"}}}"
1535+
gte0 = f16[128] get-tuple-element(call), index=0
1536+
gte1 = f16[128] get-tuple-element(call), index=1
1537+
ROOT r = (f16[128],f16[128]) tuple(gte0, gte1)
1538+
}
1539+
1540+
ENTRY e {
1541+
e.in0 = f16[128] parameter(0)
1542+
e.in1 = f16[128] parameter(1)
1543+
loop_call = (f16[128],f16[128]) call(e.in0, e.in1), to_apply=loop, backend_config="{\"callConfig\":{\"type\":\"RepeatLoop\",\"repeatConfig\":{\"repeatCount\":\"100\"}}}"
1544+
gte0 = f16[128] get-tuple-element(loop_call), index=0
1545+
gte1 = f16[128] get-tuple-element(loop_call), index=1
1546+
ROOT r = (f16[128],f16[128]) tuple(gte0, gte1)
1547+
}
1548+
)";
1549+
1550+
auto config = GetModuleConfigForTest();
1551+
config.set_argument_input_indices({});
1552+
config.set_resource_input_indices({0, 1});
1553+
config.set_resource_input_initialized({true, true});
1554+
config.set_resource_update_to_input_index({0, 1});
1555+
TF_ASSERT_OK_AND_ASSIGN(auto module,
1556+
ParseAndReturnVerifiedModule(hlo, config));
1557+
1558+
HloInstruction* loop =
1559+
CHECK_NOTNULL(FindInstruction(module.get(), "loop_call"));
1560+
1561+
const uint64 partition_replication_factor = 2;
1562+
const uint64 ipu_link_domain_replication_factor = 4;
1563+
const uint64 global_replication_factor = 8;
1564+
1565+
CompilerAnnotations annotations(module.get());
1566+
TF_ASSERT_OK_AND_ASSIGN(
1567+
bool offloaded,
1568+
VariablesOffloadAndPartition(
1569+
annotations, /*remote_memory_supported=*/true,
1570+
/*minimum_remote_tensor_size=*/4, partition_replication_factor)
1571+
.Run(module.get()));
1572+
EXPECT_TRUE(offloaded);
1573+
1574+
ReplicatedResourceUpdateElementwiseClustering pass(
1575+
annotations, partition_replication_factor, global_replication_factor,
1576+
ipu_link_domain_replication_factor);
1577+
auto elementwise_comps =
1578+
ElementwiseCluster::GetElementwiseClusterableComputations(module.get());
1579+
TF_ASSERT_OK_AND_ASSIGN(auto clusters,
1580+
pass.GetClustersIn(loop, elementwise_comps));
1581+
ASSERT_THAT(clusters.size(), 1);
1582+
auto& cluster = clusters.front();
1583+
const auto status = pass.OutlineCluster(cluster).status();
1584+
ASSERT_THAT(status.code(), tensorflow::error::UNIMPLEMENTED);
1585+
EXPECT_THAT(status.error_message(),
1586+
::testing::StartsWith(
1587+
"Replicated partitioning is not supported when there are "
1588+
"multiple instances per IPU-link domain."));
1589+
}
1590+
15001591
TEST_F(TestPartitionReplicationFactor, TestNonGlobalAllReduce) {
15011592
const std::string hlo = R"(
15021593
HloModule main
@@ -1569,7 +1660,8 @@ TEST_F(TestPartitionReplicationFactor, TestNonGlobalAllReduce) {
15691660
EXPECT_TRUE(offloaded);
15701661

15711662
ReplicatedResourceUpdateElementwiseClustering pass(
1572-
annotations, partition_replication_factor, global_replication_factor);
1663+
annotations, partition_replication_factor, global_replication_factor,
1664+
global_replication_factor);
15731665
auto elementwise_comps =
15741666
ElementwiseCluster::GetElementwiseClusterableComputations(module.get());
15751667
TF_ASSERT_OK_AND_ASSIGN(auto clusters,

0 commit comments

Comments
 (0)