@@ -146,7 +146,7 @@ ENTRY main {
146146 EXPECT_NE (plan1, plan2);
147147}
148148
149- TEST_F (SlicePlanTest, ShareSliceAndUpdatePlan ) {
149+ TEST_F (SlicePlanTest, ShareSliceAndUpdateAddPlan ) {
150150 const string& hlo_string = R"(
151151HloModule main
152152
@@ -181,7 +181,7 @@ ENTRY main {
181181 EXPECT_EQ (plan1, plan2);
182182}
183183
184- TEST_F (SlicePlanTest, ShareMultipleSliceAndUpdatePlan ) {
184+ TEST_F (SlicePlanTest, ShareMultipleSliceAndUpdateAddPlan ) {
185185 const string& hlo_string = R"(
186186HloModule main
187187
@@ -225,7 +225,7 @@ ENTRY main {
225225 EXPECT_EQ (plan1, plan3);
226226}
227227
228- TEST_F (SlicePlanTest, DontShareSliceAndUpdatePlan ) {
228+ TEST_F (SlicePlanTest, DontShareSliceAndUpdateAddPlan ) {
229229 const string& hlo_string = R"(
230230HloModule main
231231
@@ -264,6 +264,120 @@ ENTRY main {
264264 EXPECT_EQ (plan1, plan2);
265265 EXPECT_NE (plan1, plan3);
266266}
267+
268+ TEST_F (SlicePlanTest, ShareSliceAndUpdatePlan) {
269+ const string& hlo_string = R"(
270+ HloModule main
271+
272+ ENTRY main {
273+ input = f32[100,16] parameter(0)
274+ offsets = s32[24,1] parameter(1)
275+ slice = f32[24,16] custom-call(input, offsets), custom_call_target="MultiSlice"
276+ one = f32[] constant(1)
277+ big_one = f32[24,16] broadcast(one), dimensions={}
278+ slice_modified = f32[24,16] add(slice, big_one)
279+ update = f32[100,16] custom-call(input, offsets, slice_modified), custom_call_target="MultiUpdate", backend_config="{\"index_vector_dim\":1,\"update_dim\":1}\n"
280+ ROOT t = (f32[24,16], f32[100,16]) tuple(slice, update)
281+ }
282+ )" ;
283+ std::unique_ptr<HloModule> module =
284+ ParseAndReturnVerifiedModule (hlo_string).ConsumeValueOrDie ();
285+ auto resources = GetMockResources (module .get (), false );
286+ HloPassPipeline pipeline = GetMockPipeline (*resources.get ());
287+ EXPECT_TRUE (pipeline.Run (module .get ()).ValueOrDie ());
288+ TF_EXPECT_OK (
289+ EmbeddingPlansPreplanning (*resources).Run (module .get ()).status ());
290+ auto entry_computation = module ->entry_computation ();
291+ EntryVisitor visitor (*resources.get (), entry_computation);
292+ TF_EXPECT_OK (entry_computation->Accept (&visitor));
293+
294+ auto root = entry_computation->root_instruction ();
295+ auto slice = root->operand (0 );
296+ auto update = root->operand (1 );
297+ TF_ASSERT_OK_AND_ASSIGN (auto plan1, GetSlicePlan (*resources, slice));
298+ TF_ASSERT_OK_AND_ASSIGN (auto plan2, GetSlicePlan (*resources, update));
299+ EXPECT_EQ (plan1, plan2);
300+ }
301+
302+ TEST_F (SlicePlanTest, ShareMultipleSliceAndUpdatePlan) {
303+ const string& hlo_string = R"(
304+ HloModule main
305+
306+ ENTRY main {
307+ input = f32[100,16] parameter(0)
308+ offsets1 = s32[24,1] parameter(1)
309+ offsets2 = s32[12,1] parameter(2)
310+ slice1 = f32[24,16] custom-call(input, offsets1), custom_call_target="MultiSlice"
311+ slice2 = f32[12,16] custom-call(input, offsets2), custom_call_target="MultiSlice"
312+ one = f32[] constant(1)
313+ big_one1 = f32[24,16] broadcast(one), dimensions={}
314+ slice1_modified = f32[24,16] add(slice1, big_one1)
315+ big_one2 = f32[12,16] broadcast(one), dimensions={}
316+ slice2_modified = f32[12,16] add(slice2, big_one2)
317+ concat_offsets = s32[36,1] concatenate(offsets1, offsets2), dimensions={0}
318+ concat_updates = f32[36,16] concatenate(slice1_modified, slice2_modified), dimensions={0}
319+ update = f32[100,16] custom-call(input, concat_offsets, concat_updates), custom_call_target="MultiUpdate", backend_config="{\"index_vector_dim\":1,\"update_dim\":1}\n"
320+ ROOT t = (f32[24,16], f32[12,16], f32[100,16]) tuple(slice1, slice2, update)
321+ }
322+ )" ;
323+ std::unique_ptr<HloModule> module =
324+ ParseAndReturnVerifiedModule (hlo_string).ConsumeValueOrDie ();
325+ auto resources = GetMockResources (module .get (), false );
326+ HloPassPipeline pipeline = GetMockPipeline (*resources.get ());
327+ EXPECT_TRUE (pipeline.Run (module .get ()).ValueOrDie ());
328+ TF_EXPECT_OK (
329+ EmbeddingPlansPreplanning (*resources).Run (module .get ()).status ());
330+ auto entry_computation = module ->entry_computation ();
331+ EntryVisitor visitor (*resources.get (), entry_computation);
332+ TF_EXPECT_OK (entry_computation->Accept (&visitor));
333+
334+ auto root = entry_computation->root_instruction ();
335+ auto slice = root->operand (0 );
336+ auto update = root->operand (1 );
337+ TF_ASSERT_OK_AND_ASSIGN (auto plan1, GetSlicePlan (*resources, slice));
338+ TF_ASSERT_OK_AND_ASSIGN (auto plan2, GetSlicePlan (*resources, update));
339+ EXPECT_EQ (plan1, plan2);
340+ }
341+
342+ TEST_F (SlicePlanTest, DontShareSliceAndUpdatePlan) {
343+ const string& hlo_string = R"(
344+ HloModule main
345+
346+ ENTRY main {
347+ input = f32[100,16] parameter(0)
348+ offsets1 = s32[24,1] parameter(1)
349+ offsets2 = s32[12,1] parameter(2)
350+ slice1 = f32[24,16] custom-call(input, offsets1), custom_call_target="MultiSlice"
351+ slice2 = f32[12,16] custom-call(input, offsets2), custom_call_target="MultiSlice"
352+ one = f32[] constant(1)
353+ big_one1 = f32[24,16] broadcast(one), dimensions={}
354+ slice1_modified = f32[24,16] add(slice1, big_one1)
355+ update = f32[100,16] custom-call(input, offsets1, slice1_modified), custom_call_target="MultiUpdate", backend_config="{\"index_vector_dim\":1,\"update_dim\":1}\n"
356+ ROOT t = (f32[24,16], f32[12,16], f32[100,16]) tuple(slice1, slice2, update)
357+ }
358+ )" ;
359+ std::unique_ptr<HloModule> module =
360+ ParseAndReturnVerifiedModule (hlo_string).ConsumeValueOrDie ();
361+ auto resources = GetMockResources (module .get (), false );
362+ HloPassPipeline pipeline = GetMockPipeline (*resources.get ());
363+ EXPECT_TRUE (pipeline.Run (module .get ()).ValueOrDie ());
364+ TF_EXPECT_OK (
365+ EmbeddingPlansPreplanning (*resources).Run (module .get ()).status ());
366+ auto entry_computation = module ->entry_computation ();
367+ EntryVisitor visitor (*resources.get (), entry_computation);
368+ TF_EXPECT_OK (entry_computation->Accept (&visitor));
369+
370+ auto root = entry_computation->root_instruction ();
371+ auto slice1 = root->operand (0 );
372+ auto slice2 = root->operand (1 );
373+ auto update = root->operand (2 );
374+ TF_ASSERT_OK_AND_ASSIGN (auto plan1, GetSlicePlan (*resources, slice1));
375+ TF_ASSERT_OK_AND_ASSIGN (auto plan2, GetSlicePlan (*resources, slice2));
376+ TF_ASSERT_OK_AND_ASSIGN (auto plan3, GetSlicePlan (*resources, update));
377+ EXPECT_EQ (plan1, plan2);
378+ EXPECT_NE (plan1, plan3);
379+ }
380+
267381} // namespace
268382} // namespace poplarplugin
269383} // namespace xla
0 commit comments