Skip to content

Commit b4c53f0

Browse files
committed
Create slice plans for MultiUpdate
Summary: `MultiUpdate` now accepts non-empty plans as valid, so we can now generate them and pass them to the op. Test Plan: Added three new tests which borrow from the existing slice plan tests in `slice_plan_test.cc`: * `ShareSliceAndUpdatePlan` * `ShareMultipleSliceAndUpdatePlan` * `DontShareSliceAndUpdatePlan` Reviewers: georgep, samuelh, #tensorflow, #framework_ip_review_-_any_oss_or_third-party_code_use_has_been_approved Reviewed By: georgep, #tensorflow, #framework_ip_review_-_any_oss_or_third-party_code_use_has_been_approved Maniphest Tasks: T46140 Differential Revision: https://phabricator.sourcevertex.net/D52382
1 parent ea3b734 commit b4c53f0

File tree

2 files changed

+121
-18
lines changed

2 files changed

+121
-18
lines changed

tensorflow/compiler/plugin/poplar/driver/poplar_passes/embedding_plans_preplanning.cc

Lines changed: 4 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -118,18 +118,6 @@ StatusOr<SlicePlanMap> GetSlicePlans(const InputToSliceUsersMap& user_map,
118118
return result;
119119
}
120120

121-
// TODO MultiUpdate instructions do not currently support plans.
122-
StatusOr<SlicePlanMap> GetEmptyPlans(const InputToSliceUsersMap& user_map,
123-
CompilerResources& res) {
124-
SlicePlanMap result;
125-
for (auto& pair : user_map) {
126-
const HloInstruction* operand = pair.first;
127-
res.slice_plans.push_back(popops::SlicePlan());
128-
result[operand] = {&res.slice_plans.back(), {}};
129-
}
130-
return result;
131-
}
132-
133121
Status PopulateWithPlans(const InputToSliceUsersMap& user_map,
134122
const SlicePlanMap& plans, CompilerResources& res) {
135123
for (auto& pair : user_map) {
@@ -205,9 +193,10 @@ StatusOr<bool> EmbeddingPlansPreplanning::Run(HloModule* module) {
205193
SlicePlanMap multi_update_add_plans,
206194
GetSlicePlans(multi_update_adds, resources_, multi_slice_plans));
207195

208-
// Create empty plans for multi updates.
209-
TF_ASSIGN_OR_RETURN(SlicePlanMap multi_update_plans,
210-
GetEmptyPlans(multi_updates, resources_));
196+
// Same as above, but for multi-update.
197+
TF_ASSIGN_OR_RETURN(
198+
SlicePlanMap multi_update_plans,
199+
GetSlicePlans(multi_updates, resources_, multi_slice_plans));
211200

212201
// Populate the slice plans in the CompilerResources.
213202
TF_RETURN_IF_ERROR(

tensorflow/compiler/plugin/poplar/tests/slice_plan_test.cc

Lines changed: 117 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -146,7 +146,7 @@ ENTRY main {
146146
EXPECT_NE(plan1, plan2);
147147
}
148148

149-
TEST_F(SlicePlanTest, ShareSliceAndUpdatePlan) {
149+
TEST_F(SlicePlanTest, ShareSliceAndUpdateAddPlan) {
150150
const string& hlo_string = R"(
151151
HloModule main
152152
@@ -181,7 +181,7 @@ ENTRY main {
181181
EXPECT_EQ(plan1, plan2);
182182
}
183183

184-
TEST_F(SlicePlanTest, ShareMultipleSliceAndUpdatePlan) {
184+
TEST_F(SlicePlanTest, ShareMultipleSliceAndUpdateAddPlan) {
185185
const string& hlo_string = R"(
186186
HloModule main
187187
@@ -225,7 +225,7 @@ ENTRY main {
225225
EXPECT_EQ(plan1, plan3);
226226
}
227227

228-
TEST_F(SlicePlanTest, DontShareSliceAndUpdatePlan) {
228+
TEST_F(SlicePlanTest, DontShareSliceAndUpdateAddPlan) {
229229
const string& hlo_string = R"(
230230
HloModule main
231231
@@ -264,6 +264,120 @@ ENTRY main {
264264
EXPECT_EQ(plan1, plan2);
265265
EXPECT_NE(plan1, plan3);
266266
}
267+
268+
TEST_F(SlicePlanTest, ShareSliceAndUpdatePlan) {
269+
const string& hlo_string = R"(
270+
HloModule main
271+
272+
ENTRY main {
273+
input = f32[100,16] parameter(0)
274+
offsets = s32[24,1] parameter(1)
275+
slice = f32[24,16] custom-call(input, offsets), custom_call_target="MultiSlice"
276+
one = f32[] constant(1)
277+
big_one = f32[24,16] broadcast(one), dimensions={}
278+
slice_modified = f32[24,16] add(slice, big_one)
279+
update = f32[100,16] custom-call(input, offsets, slice_modified), custom_call_target="MultiUpdate", backend_config="{\"index_vector_dim\":1,\"update_dim\":1}\n"
280+
ROOT t = (f32[24,16], f32[100,16]) tuple(slice, update)
281+
}
282+
)";
283+
std::unique_ptr<HloModule> module =
284+
ParseAndReturnVerifiedModule(hlo_string).ConsumeValueOrDie();
285+
auto resources = GetMockResources(module.get(), false);
286+
HloPassPipeline pipeline = GetMockPipeline(*resources.get());
287+
EXPECT_TRUE(pipeline.Run(module.get()).ValueOrDie());
288+
TF_EXPECT_OK(
289+
EmbeddingPlansPreplanning(*resources).Run(module.get()).status());
290+
auto entry_computation = module->entry_computation();
291+
EntryVisitor visitor(*resources.get(), entry_computation);
292+
TF_EXPECT_OK(entry_computation->Accept(&visitor));
293+
294+
auto root = entry_computation->root_instruction();
295+
auto slice = root->operand(0);
296+
auto update = root->operand(1);
297+
TF_ASSERT_OK_AND_ASSIGN(auto plan1, GetSlicePlan(*resources, slice));
298+
TF_ASSERT_OK_AND_ASSIGN(auto plan2, GetSlicePlan(*resources, update));
299+
EXPECT_EQ(plan1, plan2);
300+
}
301+
302+
TEST_F(SlicePlanTest, ShareMultipleSliceAndUpdatePlan) {
303+
const string& hlo_string = R"(
304+
HloModule main
305+
306+
ENTRY main {
307+
input = f32[100,16] parameter(0)
308+
offsets1 = s32[24,1] parameter(1)
309+
offsets2 = s32[12,1] parameter(2)
310+
slice1 = f32[24,16] custom-call(input, offsets1), custom_call_target="MultiSlice"
311+
slice2 = f32[12,16] custom-call(input, offsets2), custom_call_target="MultiSlice"
312+
one = f32[] constant(1)
313+
big_one1 = f32[24,16] broadcast(one), dimensions={}
314+
slice1_modified = f32[24,16] add(slice1, big_one1)
315+
big_one2 = f32[12,16] broadcast(one), dimensions={}
316+
slice2_modified = f32[12,16] add(slice2, big_one2)
317+
concat_offsets = s32[36,1] concatenate(offsets1, offsets2), dimensions={0}
318+
concat_updates = f32[36,16] concatenate(slice1_modified, slice2_modified), dimensions={0}
319+
update = f32[100,16] custom-call(input, concat_offsets, concat_updates), custom_call_target="MultiUpdate", backend_config="{\"index_vector_dim\":1,\"update_dim\":1}\n"
320+
ROOT t = (f32[24,16], f32[12,16], f32[100,16]) tuple(slice1, slice2, update)
321+
}
322+
)";
323+
std::unique_ptr<HloModule> module =
324+
ParseAndReturnVerifiedModule(hlo_string).ConsumeValueOrDie();
325+
auto resources = GetMockResources(module.get(), false);
326+
HloPassPipeline pipeline = GetMockPipeline(*resources.get());
327+
EXPECT_TRUE(pipeline.Run(module.get()).ValueOrDie());
328+
TF_EXPECT_OK(
329+
EmbeddingPlansPreplanning(*resources).Run(module.get()).status());
330+
auto entry_computation = module->entry_computation();
331+
EntryVisitor visitor(*resources.get(), entry_computation);
332+
TF_EXPECT_OK(entry_computation->Accept(&visitor));
333+
334+
auto root = entry_computation->root_instruction();
335+
auto slice = root->operand(0);
336+
auto update = root->operand(1);
337+
TF_ASSERT_OK_AND_ASSIGN(auto plan1, GetSlicePlan(*resources, slice));
338+
TF_ASSERT_OK_AND_ASSIGN(auto plan2, GetSlicePlan(*resources, update));
339+
EXPECT_EQ(plan1, plan2);
340+
}
341+
342+
TEST_F(SlicePlanTest, DontShareSliceAndUpdatePlan) {
343+
const string& hlo_string = R"(
344+
HloModule main
345+
346+
ENTRY main {
347+
input = f32[100,16] parameter(0)
348+
offsets1 = s32[24,1] parameter(1)
349+
offsets2 = s32[12,1] parameter(2)
350+
slice1 = f32[24,16] custom-call(input, offsets1), custom_call_target="MultiSlice"
351+
slice2 = f32[12,16] custom-call(input, offsets2), custom_call_target="MultiSlice"
352+
one = f32[] constant(1)
353+
big_one1 = f32[24,16] broadcast(one), dimensions={}
354+
slice1_modified = f32[24,16] add(slice1, big_one1)
355+
update = f32[100,16] custom-call(input, offsets1, slice1_modified), custom_call_target="MultiUpdate", backend_config="{\"index_vector_dim\":1,\"update_dim\":1}\n"
356+
ROOT t = (f32[24,16], f32[12,16], f32[100,16]) tuple(slice1, slice2, update)
357+
}
358+
)";
359+
std::unique_ptr<HloModule> module =
360+
ParseAndReturnVerifiedModule(hlo_string).ConsumeValueOrDie();
361+
auto resources = GetMockResources(module.get(), false);
362+
HloPassPipeline pipeline = GetMockPipeline(*resources.get());
363+
EXPECT_TRUE(pipeline.Run(module.get()).ValueOrDie());
364+
TF_EXPECT_OK(
365+
EmbeddingPlansPreplanning(*resources).Run(module.get()).status());
366+
auto entry_computation = module->entry_computation();
367+
EntryVisitor visitor(*resources.get(), entry_computation);
368+
TF_EXPECT_OK(entry_computation->Accept(&visitor));
369+
370+
auto root = entry_computation->root_instruction();
371+
auto slice1 = root->operand(0);
372+
auto slice2 = root->operand(1);
373+
auto update = root->operand(2);
374+
TF_ASSERT_OK_AND_ASSIGN(auto plan1, GetSlicePlan(*resources, slice1));
375+
TF_ASSERT_OK_AND_ASSIGN(auto plan2, GetSlicePlan(*resources, slice2));
376+
TF_ASSERT_OK_AND_ASSIGN(auto plan3, GetSlicePlan(*resources, update));
377+
EXPECT_EQ(plan1, plan2);
378+
EXPECT_NE(plan1, plan3);
379+
}
380+
267381
} // namespace
268382
} // namespace poplarplugin
269383
} // namespace xla

0 commit comments

Comments
 (0)