@@ -2121,6 +2121,153 @@ TEST_F(PoplarAlgebraicSimplifierTest, SimplifyConcatenateOfSlices) {
21212121 EXPECT_EQ (computation->root_instruction ()->operand (3 )->slice_starts (1 ), 40 );
21222122}
21232123
2124+ TEST_F (PoplarAlgebraicSimplifierTest, SimplifyConcatenateOfSameSlices_1_2_2_1) {
2125+ auto m = CreateNewVerifiedModule ();
2126+ Shape r2f32 = ShapeUtil::MakeShape (F32, {3 , 4 });
2127+ Shape concat_shape = ShapeUtil::MakeShape (F32, {3 , 6 });
2128+ HloComputation::Builder builder (TestName ());
2129+ HloInstruction* param0 = builder.AddInstruction (
2130+ HloInstruction::CreateParameter (0 , r2f32, " param0" ));
2131+
2132+ HloInstruction* slice0 = builder.AddInstruction (HloInstruction::CreateSlice (
2133+ ShapeUtil::MakeShape (F32, {3 , 1 }), param0, /* start_indices=*/ {0 , 0 },
2134+ /* limit_indices=*/ {3 , 1 }, /* strides=*/ {1 , 1 }));
2135+
2136+ HloInstruction* slice1 = builder.AddInstruction (HloInstruction::CreateSlice (
2137+ ShapeUtil::MakeShape (F32, {3 , 1 }), param0, /* start_indices=*/ {0 , 1 },
2138+ /* limit_indices=*/ {3 , 2 }, /* strides=*/ {1 , 1 }));
2139+
2140+ HloInstruction* slice2 = builder.AddInstruction (HloInstruction::CreateSlice (
2141+ ShapeUtil::MakeShape (F32, {3 , 1 }), param0, /* start_indices=*/ {0 , 2 },
2142+ /* limit_indices=*/ {3 , 3 }, /* strides=*/ {1 , 1 }));
2143+
2144+ HloInstruction* slice3 = builder.AddInstruction (HloInstruction::CreateSlice (
2145+ ShapeUtil::MakeShape (F32, {3 , 1 }), param0, /* start_indices=*/ {0 , 3 },
2146+ /* limit_indices=*/ {3 , 4 }, /* strides=*/ {1 , 1 }));
2147+
2148+ builder.AddInstruction (HloInstruction::CreateConcatenate (
2149+ concat_shape, {slice0, slice1, slice1, slice2, slice2, slice3}, 1 ));
2150+ auto computation = m->AddEntryComputation (builder.Build ());
2151+
2152+ PoplarAlgebraicSimplifier simplifier;
2153+ ASSERT_TRUE (simplifier.Run (m.get ()).ValueOrDie ());
2154+
2155+ auto root = computation->root_instruction ();
2156+ EXPECT_THAT (root, GmockMatch (m::Concatenate (
2157+ m::Slice (m::Parameter (0 )),
2158+ m::Reshape (m::Broadcast (m::Slice (m::Parameter (0 )))),
2159+ m::Reshape (m::Broadcast (m::Slice (m::Parameter (0 )))),
2160+ m::Slice (m::Parameter (0 )))));
2161+ // Check shape of broadcast
2162+ EXPECT_TRUE (ShapeUtil::Equal (root->operand (1 )->operand (0 )->shape (),
2163+ ShapeUtil::MakeShape (F32, {3 , 1 , 2 })));
2164+ // check slices start/limit indices
2165+ EXPECT_THAT (root->operand (0 )->slice_starts (), ElementsAre (0 , 0 ));
2166+ EXPECT_THAT (root->operand (0 )->slice_limits (), ElementsAre (3 , 1 ));
2167+ EXPECT_THAT (root->operand (1 )->operand (0 )->operand (0 )->slice_starts (),
2168+ ElementsAre (0 , 1 ));
2169+ EXPECT_THAT (root->operand (1 )->operand (0 )->operand (0 )->slice_limits (),
2170+ ElementsAre (3 , 2 ));
2171+ EXPECT_THAT (root->operand (2 )->operand (0 )->operand (0 )->slice_starts (),
2172+ ElementsAre (0 , 2 ));
2173+ EXPECT_THAT (root->operand (2 )->operand (0 )->operand (0 )->slice_limits (),
2174+ ElementsAre (3 , 3 ));
2175+ EXPECT_THAT (root->operand (3 )->slice_starts (), ElementsAre (0 , 3 ));
2176+ EXPECT_THAT (root->operand (3 )->slice_limits (), ElementsAre (3 , 4 ));
2177+ }
2178+
2179+ TEST_F (PoplarAlgebraicSimplifierTest, SimplifyConcatenateOfSameSlices_5_5_5_5) {
2180+ auto m = CreateNewVerifiedModule ();
2181+ Shape r2f32 = ShapeUtil::MakeShape (F32, {1 , 4 });
2182+ Shape concat_shape = ShapeUtil::MakeShape (F32, {1 , 20 });
2183+ HloComputation::Builder builder (TestName ());
2184+ HloInstruction* param0 = builder.AddInstruction (
2185+ HloInstruction::CreateParameter (0 , r2f32, " param0" ));
2186+
2187+ HloInstruction* slice0 = builder.AddInstruction (HloInstruction::CreateSlice (
2188+ ShapeUtil::MakeShape (F32, {1 , 1 }), param0, /* start_indices=*/ {0 , 0 },
2189+ /* limit_indices=*/ {1 , 1 }, /* strides=*/ {1 , 1 }));
2190+
2191+ HloInstruction* slice1 = builder.AddInstruction (HloInstruction::CreateSlice (
2192+ ShapeUtil::MakeShape (F32, {1 , 1 }), param0, /* start_indices=*/ {0 , 1 },
2193+ /* limit_indices=*/ {1 , 2 }, /* strides=*/ {1 , 1 }));
2194+
2195+ HloInstruction* slice2 = builder.AddInstruction (HloInstruction::CreateSlice (
2196+ ShapeUtil::MakeShape (F32, {1 , 1 }), param0, /* start_indices=*/ {0 , 2 },
2197+ /* limit_indices=*/ {1 , 3 }, /* strides=*/ {1 , 1 }));
2198+
2199+ HloInstruction* slice3 = builder.AddInstruction (HloInstruction::CreateSlice (
2200+ ShapeUtil::MakeShape (F32, {1 , 1 }), param0, /* start_indices=*/ {0 , 3 },
2201+ /* limit_indices=*/ {1 , 4 }, /* strides=*/ {1 , 1 }));
2202+
2203+ builder.AddInstruction (HloInstruction::CreateConcatenate (
2204+ concat_shape, {slice0, slice0, slice0, slice0, slice0, slice1, slice1,
2205+ slice1, slice1, slice1, slice2, slice2, slice2, slice2,
2206+ slice2, slice3, slice3, slice3, slice3, slice3},
2207+ 1 ));
2208+ auto computation = m->AddEntryComputation (builder.Build ());
2209+
2210+ PoplarAlgebraicSimplifier simplifier;
2211+ ASSERT_TRUE (simplifier.Run (m.get ()).ValueOrDie ());
2212+
2213+ auto root = computation->root_instruction ();
2214+ EXPECT_THAT (root, GmockMatch (m::Concatenate (
2215+ m::Reshape (m::Broadcast (m::Slice (m::Parameter (0 )))),
2216+ m::Reshape (m::Broadcast (m::Slice (m::Parameter (0 )))),
2217+ m::Reshape (m::Broadcast (m::Slice (m::Parameter (0 )))),
2218+ m::Reshape (m::Broadcast (m::Slice (m::Parameter (0 )))))));
2219+ EXPECT_TRUE (ShapeUtil::Equal (root->operand (0 )->operand (0 )->shape (),
2220+ ShapeUtil::MakeShape (F32, {1 , 1 , 5 })));
2221+ EXPECT_THAT (root->operand (0 )->operand (0 )->operand (0 )->slice_starts (),
2222+ ElementsAre (0 , 0 ));
2223+ EXPECT_THAT (root->operand (0 )->operand (0 )->operand (0 )->slice_limits (),
2224+ ElementsAre (1 , 1 ));
2225+ EXPECT_THAT (root->operand (1 )->operand (0 )->operand (0 )->slice_starts (),
2226+ ElementsAre (0 , 1 ));
2227+ EXPECT_THAT (root->operand (1 )->operand (0 )->operand (0 )->slice_limits (),
2228+ ElementsAre (1 , 2 ));
2229+ EXPECT_THAT (root->operand (2 )->operand (0 )->operand (0 )->slice_starts (),
2230+ ElementsAre (0 , 2 ));
2231+ EXPECT_THAT (root->operand (2 )->operand (0 )->operand (0 )->slice_limits (),
2232+ ElementsAre (1 , 3 ));
2233+ EXPECT_THAT (root->operand (3 )->operand (0 )->operand (0 )->slice_starts (),
2234+ ElementsAre (0 , 3 ));
2235+ EXPECT_THAT (root->operand (3 )->operand (0 )->operand (0 )->slice_limits (),
2236+ ElementsAre (1 , 4 ));
2237+ }
2238+
2239+ TEST_F (PoplarAlgebraicSimplifierTest, SimplifyConcatenateOfSameSlices_1x5) {
2240+ auto m = CreateNewVerifiedModule ();
2241+ Shape r2f32 = ShapeUtil::MakeShape (F32, {1 , 4 });
2242+ Shape concat_shape = ShapeUtil::MakeShape (F32, {1 , 5 });
2243+ HloComputation::Builder builder (TestName ());
2244+ HloInstruction* param0 = builder.AddInstruction (
2245+ HloInstruction::CreateParameter (0 , r2f32, " param0" ));
2246+
2247+ HloInstruction* slice0 = builder.AddInstruction (HloInstruction::CreateSlice (
2248+ ShapeUtil::MakeShape (F32, {1 , 1 }), param0, /* start_indices=*/ {0 , 0 },
2249+ /* limit_indices=*/ {1 , 1 }, /* strides=*/ {1 , 1 }));
2250+
2251+ builder.AddInstruction (HloInstruction::CreateConcatenate (
2252+ concat_shape, {slice0, slice0, slice0, slice0, slice0}, 1 ));
2253+ auto computation = m->AddEntryComputation (builder.Build ());
2254+
2255+ PoplarAlgebraicSimplifier simplifier;
2256+ ASSERT_TRUE (simplifier.Run (m.get ()).ValueOrDie ());
2257+
2258+ auto root = computation->root_instruction ();
2259+
2260+ // There shouldn't be the concatenate if the inputs to the concatenate are all
2261+ // from the same slice
2262+ EXPECT_THAT (root,
2263+ GmockMatch (m::Reshape (m::Broadcast (m::Slice (m::Parameter (0 ))))));
2264+
2265+ EXPECT_TRUE (ShapeUtil::Equal (root->operand (0 )->shape (),
2266+ ShapeUtil::MakeShape (F32, {1 , 1 , 5 })));
2267+ EXPECT_THAT (root->operand (0 )->operand (0 )->slice_starts (), ElementsAre (0 , 0 ));
2268+ EXPECT_THAT (root->operand (0 )->operand (0 )->slice_limits (), ElementsAre (1 , 1 ));
2269+ }
2270+
21242271// Test transforming reshapes and transposes of rng.
21252272TEST_F (PoplarAlgebraicSimplifierTest, ReshapeOfTransposeOfRngToRng) {
21262273 auto m = CreateNewVerifiedModule ();
0 commit comments