@@ -1436,7 +1436,8 @@ TEST_F(TestPartitionReplicationFactor, TestCollectiveGroups) {
14361436 EXPECT_TRUE (offloaded);
14371437
14381438 ReplicatedResourceUpdateElementwiseClustering pass (
1439- annotations, partition_replication_factor, global_replication_factor);
1439+ annotations, partition_replication_factor, global_replication_factor,
1440+ global_replication_factor);
14401441 auto elementwise_comps =
14411442 ElementwiseCluster::GetElementwiseClusterableComputations (module .get ());
14421443 TF_ASSERT_OK_AND_ASSIGN (auto clusters,
@@ -1497,6 +1498,96 @@ TEST_F(TestPartitionReplicationFactor, TestCollectiveGroups) {
14971498 partition_replication_factor));
14981499}
14991500
// TestUnsupportedPartitioning: negative test for replicated partitioning.
// Parses an HLO module whose repeat loop calls a resource update configured
// with offloaded, partitioned variables, runs VariablesOffloadAndPartition on
// it, and then expects OutlineCluster() on the single cluster found to return
// an UNIMPLEMENTED status. Unlike the neighbouring tests in this diff, which
// pass global_replication_factor as the pass's last constructor argument, this
// test passes a separate, smaller ipu_link_domain_replication_factor.
// NOTE(review): every line below carries a diff-view prefix (e.g. `1501+`) and
// scraper-inserted token spacing; the leading spaces inside string literals
// such as " loop_call" are presumably artifacts too - confirm against the real
// file before relying on the exact literal contents.
1501+ TEST_F (TestPartitionReplicationFactor, TestUnsupportedPartitioning) {
1502+ const std::string hlo = R"(
1503+ HloModule main
1504+
1505+ sum {
1506+ y = f16[] parameter(1)
1507+ x = f16[] parameter(0), control-predecessors={y}
1508+ ROOT add = f16[] add(x, y), backend_config="{\"isInplace\":true}"
1509+ }
1510+
1511+ resource_update {
1512+ arg0 = f16[128] parameter(0)
1513+ arg1 = f16[128] parameter(1)
1514+ arg2 = f16[128] parameter(2)
1515+
1516+ arg0_r = f16[128] all-reduce(arg0), to_apply=sum
1517+ arg2_new = f16[128] add(arg0_r, arg2)
1518+ arg1_new = f16[128] add(arg1, arg2_new)
1519+
1520+ ROOT t = (f16[128],f16[128]) tuple(arg1_new, arg2_new)
1521+ counter_0 = s32[] constant(4)
1522+ gac = () custom-call(s32[] counter_0), custom_call_target="GradientAccumulationCount"
1523+ }
1524+
1525+ loop {
1526+ after-all = token[] after-all()
1527+ infeed = (f16[128], token[]) infeed(after-all), infeed_config="140121807314576"
1528+ input = f16[128] get-tuple-element(infeed), index=0
1529+
1530+ arg0 = f16[128] parameter(0)
1531+ arg1 = f16[128] parameter(1)
1532+
1533+ add.1 = f16[128] add(input, arg0)
1534+ call = (f16[128],f16[128]) call(add.1, arg0, arg1), to_apply=resource_update, frontend_attributes={CALL_CONFIG_TYPE="ResourceUpdate"}, backend_config="{\"callConfig\":{\"type\":\"ResourceUpdate\",\"resourceUpdateConfig\":{\"offloadVariables\":\"THREESTATE_ON\", \"partitionOffloadedVariables\":\"THREESTATE_ON\"}}}"
1535+ gte0 = f16[128] get-tuple-element(call), index=0
1536+ gte1 = f16[128] get-tuple-element(call), index=1
1537+ ROOT r = (f16[128],f16[128]) tuple(gte0, gte1)
1538+ }
1539+
1540+ ENTRY e {
1541+ e.in0 = f16[128] parameter(0)
1542+ e.in1 = f16[128] parameter(1)
1543+ loop_call = (f16[128],f16[128]) call(e.in0, e.in1), to_apply=loop, backend_config="{\"callConfig\":{\"type\":\"RepeatLoop\",\"repeatConfig\":{\"repeatCount\":\"100\"}}}"
1544+ gte0 = f16[128] get-tuple-element(loop_call), index=0
1545+ gte1 = f16[128] get-tuple-element(loop_call), index=1
1546+ ROOT r = (f16[128],f16[128]) tuple(gte0, gte1)
1547+ }
1548+ )" ;
1549+
// Module config: no plain argument inputs; both entry parameters (indices 0
// and 1) are registered as already-initialised resource inputs, and the two
// resource-update outputs map back onto those same input indices.
1550+ auto config = GetModuleConfigForTest ();
1551+ config.set_argument_input_indices ({});
1552+ config.set_resource_input_indices ({0 , 1 });
1553+ config.set_resource_input_initialized ({true , true });
1554+ config.set_resource_update_to_input_index ({0 , 1 });
1555+ TF_ASSERT_OK_AND_ASSIGN (auto module ,
1556+ ParseAndReturnVerifiedModule (hlo, config));
1557+
// Look up the repeat-loop call in ENTRY; CHECK_NOTNULL aborts the test if the
// lookup yields a null pointer.
1558+ HloInstruction* loop =
1559+ CHECK_NOTNULL (FindInstruction (module .get (), " loop_call" ));
1560+
// The three replication factors deliberately differ (2 / 4 / 8). Presumably
// the IPU-link domain factor being smaller than the global factor is what
// triggers the unsupported path - TODO confirm against the pass itself.
1561+ const uint64 partition_replication_factor = 2 ;
1562+ const uint64 ipu_link_domain_replication_factor = 4 ;
1563+ const uint64 global_replication_factor = 8 ;
1564+
// Run the offload/partition pass first; the test requires it to report that
// it changed the module.
1565+ CompilerAnnotations annotations (module .get ());
1566+ TF_ASSERT_OK_AND_ASSIGN (
1567+ bool offloaded,
1568+ VariablesOffloadAndPartition (
1569+ annotations, /* remote_memory_supported=*/ true ,
1570+ /* minimum_remote_tensor_size=*/ 4 , partition_replication_factor)
1571+ .Run (module .get ()));
1572+ EXPECT_TRUE (offloaded);
1573+
// Build the clustering pass with the distinct IPU-link domain factor as its
// last constructor argument, then collect the clusters inside the loop call.
1574+ ReplicatedResourceUpdateElementwiseClustering pass (
1575+ annotations, partition_replication_factor, global_replication_factor,
1576+ ipu_link_domain_replication_factor);
1577+ auto elementwise_comps =
1578+ ElementwiseCluster::GetElementwiseClusterableComputations (module .get ());
1579+ TF_ASSERT_OK_AND_ASSIGN (auto clusters,
1580+ pass.GetClustersIn (loop, elementwise_comps));
// Exactly one cluster is expected; outlining it must fail with UNIMPLEMENTED
// and an error message that begins with the text matched below.
1581+ ASSERT_THAT (clusters.size (), 1 );
1582+ auto & cluster = clusters.front ();
1583+ const auto status = pass.OutlineCluster (cluster).status ();
1584+ ASSERT_THAT (status.code (), tensorflow::error::UNIMPLEMENTED);
1585+ EXPECT_THAT (status.error_message (),
1586+ ::testing::StartsWith (
1587+ " Replicated partitioning is not supported when there are "
1588+ " multiple instances per IPU-link domain." ));
1589+ }
1590+
15001591TEST_F (TestPartitionReplicationFactor, TestNonGlobalAllReduce) {
15011592 const std::string hlo = R"(
15021593 HloModule main
@@ -1569,7 +1660,8 @@ TEST_F(TestPartitionReplicationFactor, TestNonGlobalAllReduce) {
15691660 EXPECT_TRUE (offloaded);
15701661
15711662 ReplicatedResourceUpdateElementwiseClustering pass (
1572- annotations, partition_replication_factor, global_replication_factor);
1663+ annotations, partition_replication_factor, global_replication_factor,
1664+ global_replication_factor);
15731665 auto elementwise_comps =
15741666 ElementwiseCluster::GetElementwiseClusterableComputations (module .get ());
15751667 TF_ASSERT_OK_AND_ASSIGN (auto clusters,
0 commit comments