Skip to content

Commit 4673d3a

Browse files
BabakkGraphcore authored and georgepaw committed
TF2 - Improve ShardingPass handling of after-all tokens
Summary: Updating handling of `after-all` instructions so they explicitly get given default sharding and are not considered for copying to users. This does two things: it stops infeeds/outfeeds using the sharding of their after-all input, and it stops `ProcessComputation` using an after-all instruction as a means of kick-starting the sharding process when it can no longer make progress. An after-all doesn't help propagate sharding information, since its consumers get their sharding from elsewhere, so the next `ProcessComputation` call still doesn't make any progress; by using an after-all to kick-start things we potentially cause a tuple instruction to be prematurely given default sharding.

TF2.5 only. TF1 - D64932

Test Plan: CI + new tests

Reviewers: #tensorflow, #framework_ip_review_-_any_oss_or_third-party_code_use_has_been_approved, samuelh, vladimirm

Reviewed By: #tensorflow, #framework_ip_review_-_any_oss_or_third-party_code_use_has_been_approved, samuelh, vladimirm

Subscribers: vladimirm, samuelh

Maniphest Tasks: T59245

Differential Revision: https://phabricator.sourcevertex.net/D64354
1 parent d324b5f commit 4673d3a

File tree

3 files changed

+170
-2
lines changed

3 files changed

+170
-2
lines changed

tensorflow/compiler/plugin/poplar/driver/passes/sharding_pass.cc

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -284,7 +284,10 @@ bool CopyTupleShardingFromOperands(HloInstruction* inst) {
284284
bool CopyShardingFromOperands(HloInstruction* inst) {
285285
for (int o = 0; o < inst->operand_count(); o++) {
286286
auto* operand = inst->operand(o);
287-
if (operand->has_sharding()) {
287+
// We don't want to propagate the sharding of AfterAll tokens. The
288+
// sharding of in/outfeeds should come from their users/outfed data
289+
// respectively.
290+
if (operand->has_sharding() && operand->opcode() != HloOpcode::kAfterAll) {
288291
if (CompatibleShapes(inst->shape(), operand->shape())) {
289292
auto s = GetShardingOfOutputTensor(operand);
290293
SetSharding(inst, s);
@@ -381,9 +384,15 @@ StatusOr<bool> ProcessComputation(HloComputation* comp, int attempt) {
381384
done = true;
382385
bool made_progress = false;
383386
for (auto* inst : comp->MakeInstructionPostOrder()) {
384-
VLOG(3) << "Sharding pass visting instruction " << inst->name();
387+
VLOG(3) << "Attempt " << attempt << ": sharding pass visting instruction "
388+
<< inst->name();
385389
bool added_sharding = false;
386390

391+
if (!inst->has_sharding() && inst->opcode() == HloOpcode::kAfterAll) {
392+
SetSharding(inst, GetDefaultSharding(inst->shape()));
393+
added_sharding = true;
394+
}
395+
387396
// If an instruction has no operands, and no users but the root Tuple,
388397
// then assign default sharding
389398
if (!inst->has_sharding() && inst->operand_count() == 0 &&
@@ -433,6 +442,11 @@ StatusOr<bool> ProcessComputation(HloComputation* comp, int attempt) {
433442
// These are dealt with by the computation level code
434443
break;
435444
}
445+
case HloOpcode::kInfeed: {
446+
// Infeeds should get their sharding from their users
447+
// not their operands.
448+
break;
449+
}
436450
default: {
437451
if (IsPoplarInstruction(PoplarOp::Barrier, inst)) {
438452
added_sharding = CopyTupleShardingFromOperands(inst);

tensorflow/compiler/plugin/poplar/tests/sharding_pass_test.cc

Lines changed: 96 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -845,6 +845,102 @@ main {
845845
EXPECT_EQ(shardings[1].GetUniqueDevice(), 1);
846846
}
847847

848+
TEST_F(ShardingPassTest, TestInfeedsWithPartiallySpecifiedSharding) {
849+
// This tests that we can infer a sharding for an infeed when not
850+
// all of the tuple elements returned by the infeed have been sharded.
851+
// In this case we're feeding in two values (f16[4,64,64,3], s32[4])
852+
// but are only sharding the first one. We expect s32[4]
853+
// (get-tuple-element.34) to be assigned default sharding which will let
854+
// the infeed sharding be resolved.
855+
const std::string hlo_string = R"(
856+
HloModule top
857+
858+
main {
859+
after-all.4 = token[] after-all()
860+
inf1 = ((f16[4,64,64,3], s32[4]), token[]) infeed(after-all.4), infeed_config="\022\0011\"\002\023\003(\003"
861+
get-tuple-element.30 = (f16[4,64,64,3], s32[4]) get-tuple-element(inf1), index=0
862+
get-tuple-element.31 = f16[4,64,64,3] get-tuple-element(get-tuple-element.30), index=0
863+
arg_1 = f16[3,64] parameter(0)
864+
dot.4 = f16[4,64,64,64] dot(get-tuple-element.31, arg_1), lhs_contracting_dims={3}, rhs_contracting_dims={0}, sharding={maximal device=1}
865+
arg_2 = f16[64,64] parameter(1)
866+
dot.5 = f16[4,64,64,64] dot(dot.4, arg_2), lhs_contracting_dims={3}, rhs_contracting_dims={0}, sharding={maximal device=1}
867+
get-tuple-element.34 = s32[4] get-tuple-element(get-tuple-element.30), index=1
868+
ROOT tuple.8 = (f16[4,64,64,64], s32[4]) tuple(dot.5, get-tuple-element.34)
869+
}
870+
)";
871+
HloModuleConfig config;
872+
config.set_debug_options(GetDebugOptionsForTest());
873+
874+
auto module_or_status = ParseAndReturnVerifiedModule(hlo_string, config);
875+
EXPECT_TRUE(module_or_status.ok());
876+
877+
auto* module = module_or_status.ValueOrDie().get();
878+
auto* comp = module->entry_computation();
879+
880+
ShardingPass shardingPass;
881+
ASSERT_TRUE(shardingPass.Run(module).ValueOrDie());
882+
883+
auto* inf1 = comp->GetInstructionWithName("inf1");
884+
ASSERT_TRUE(inf1->has_sharding());
885+
ASSERT_TRUE(inf1->sharding().IsTuple());
886+
auto shardings = inf1->sharding().tuple_elements();
887+
ASSERT_TRUE(shardings[0].HasUniqueDevice());
888+
ASSERT_EQ(shardings[0].GetUniqueDevice(), 1);
889+
ASSERT_TRUE(shardings[1].HasUniqueDevice());
890+
ASSERT_EQ(shardings[1].GetUniqueDevice(), 0);
891+
}
892+
893+
TEST_F(ShardingPassTest, TestOutfeedsDontTakeTokenSharding) {
894+
// This tests that outfeeds take sharding from their outfed data
895+
// input, not from their token.
896+
const std::string hlo_string = R"(
897+
HloModule top
898+
899+
main {
900+
arg_1 = f16[3,64] parameter(1)
901+
arg_2 = f16[4,64,64,3] parameter(2)
902+
dot.4 = f16[4,64,64,64] dot(arg_2, arg_1), lhs_contracting_dims={3}, rhs_contracting_dims={0}, sharding={maximal device=0}
903+
arg_3 = f16[64,64] parameter(3)
904+
dot.5 = f16[4,64,64,64] dot(dot.4, arg_3), lhs_contracting_dims={3}, rhs_contracting_dims={0}, sharding={maximal device=1}
905+
arg_4 = s32[4] parameter(4)
906+
outfed_tuple = (f16[4,64,64,64], s32[4]) tuple(dot.5, arg_4)
907+
after-all.5 = token[] after-all()
908+
outfeed = token[] outfeed(outfed_tuple, after-all.5), outfeed_config="\022\0012\"\002\023\003(\003"
909+
arg_0 = s32[] parameter(0)
910+
ROOT tuple.10 = (s32[], f16[3,64], f16[64,64]) tuple(arg_0, arg_1, arg_3)
911+
}
912+
)";
913+
HloModuleConfig config;
914+
config.set_debug_options(GetDebugOptionsForTest());
915+
916+
auto module_or_status = ParseAndReturnVerifiedModule(hlo_string, config);
917+
EXPECT_TRUE(module_or_status.ok());
918+
919+
auto* module = module_or_status.ValueOrDie().get();
920+
auto* comp = module->entry_computation();
921+
922+
ShardingPass shardingPass;
923+
ASSERT_TRUE(shardingPass.Run(module).ValueOrDie());
924+
925+
auto* outfed_tuple = comp->GetInstructionWithName("outfed_tuple");
926+
ASSERT_TRUE(outfed_tuple->has_sharding());
927+
const auto outfed_tuple_sharding = outfed_tuple->sharding();
928+
929+
auto* outfeed = comp->GetInstructionWithName("outfeed");
930+
ASSERT_TRUE(outfeed->has_sharding());
931+
932+
ASSERT_EQ(outfeed->sharding(), outfed_tuple_sharding);
933+
934+
ASSERT_TRUE(outfed_tuple_sharding.IsTuple());
935+
auto shardings = outfed_tuple_sharding.tuple_elements();
936+
937+
ASSERT_EQ(shardings.size(), 2);
938+
ASSERT_TRUE(shardings[0].HasUniqueDevice());
939+
ASSERT_EQ(shardings[0].GetUniqueDevice(), 1);
940+
ASSERT_TRUE(shardings[1].HasUniqueDevice());
941+
ASSERT_EQ(shardings[1].GetUniqueDevice(), 0);
942+
}
943+
848944
TEST_F(ShardingPassTest, TestGteOpsMatchTheirOperands) {
849945
std::string hlo_string = R"(
850946
HloModule top

tensorflow/python/ipu/tests/infeed_outfeed_test.py

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1475,6 +1475,64 @@ def my_net():
14751475
self.assertAllEqual(np.full([1, 10, 10], 0), out[6])
14761476
self.assertAllEqual(np.full([1, 10, 10], 0), out[7])
14771477

1478+
@test_util.deprecated_graph_mode_only
1479+
def testValidSharding(self):
1480+
# Reproducer from T59245
1481+
ipu_options = ipu.config.IPUConfig()
1482+
ipu_options.auto_select_ipus = 2
1483+
1484+
ipu_options.configure_ipu_system()
1485+
ipu.utils.move_variable_initialization_to_cpu()
1486+
outfeed_queue = ipu.ipu_outfeed_queue.IPUOutfeedQueue()
1487+
1488+
images = dataset_ops.Dataset.from_tensors(
1489+
np.ones((64, 64, 3), dtype=np.float16))
1490+
labels = dataset_ops.Dataset.from_tensors(np.ones((), dtype=np.int32))
1491+
tf_dataset = dataset_ops.Dataset.zip(
1492+
(images, labels)).repeat(4).batch(4, drop_remainder=True).repeat()
1493+
1494+
with ops.device('cpu'):
1495+
infeed_queue = ipu.ipu_infeed_queue.IPUInfeedQueue(tf_dataset)
1496+
1497+
def retinanet_validating_loop():
1498+
def body(images, imgIds):
1499+
with variable_scope.variable_scope("MainGraph"):
1500+
with ipu.scopes.ipu_shard(0):
1501+
w1 = variable_scope.get_variable(
1502+
"w1",
1503+
shape=[3, 64],
1504+
dtype=np.float16,
1505+
initializer=init_ops.glorot_uniform_initializer(
1506+
dtype=np.float16))
1507+
y = math_ops.matmul(images, w1)
1508+
1509+
with ipu.scopes.ipu_shard(1):
1510+
w2 = variable_scope.get_variable(
1511+
"w2",
1512+
shape=[64, 64],
1513+
dtype=np.float16,
1514+
initializer=init_ops.glorot_uniform_initializer(
1515+
dtype=np.float16))
1516+
scores = math_ops.matmul(y, w2)
1517+
1518+
out = outfeed_queue.enqueue([scores, imgIds])
1519+
return out
1520+
1521+
return ipu.loops.repeat(128, body, inputs=[], infeed_queue=infeed_queue)
1522+
1523+
with ipu.scopes.ipu_scope('/device:IPU:0'):
1524+
retinanet_validation_step = ipu.ipu_compiler.compile(
1525+
retinanet_validating_loop, inputs=[])
1526+
1527+
session = session_lib.Session()
1528+
session.run(infeed_queue.initializer)
1529+
session.run(variables.global_variables_initializer())
1530+
try:
1531+
# This can throw if the sharding of body is incorrect.
1532+
session.run(retinanet_validation_step)
1533+
except Exception as e: # pylint: disable=broad-except
1534+
self.fail(f"Unexpected exception thrown: {e}")
1535+
14781536
@test_util.run_v2_only
14791537
def testDeduceDevice(self):
14801538
cfg = ipu.config.IPUConfig()

0 commit comments

Comments
 (0)