@@ -485,6 +485,28 @@ def fun(float(N,K) A, float(K,M) B, float(N,M) C) -> (O) {
485485 EXPECT_TRUE (cDeclPos == std::string::npos)
486486 << " tensor C promoted to register but has no reuse" ;
487487 }
488+
489+ void expectFourOElementsPromoted (const std::string& code) {
490+ auto oDeclPos = code.find (" float32 _O_0[4][1];" );
491+ EXPECT_TRUE (oDeclPos != std::string::npos)
492+ << " expected O to be promoted to registers" ;
493+
494+ expectNoABCPromotion (code);
495+
496+ auto o00Pos = code.find (" _O_0[0][0]" );
497+ auto o10Pos = code.find (" _O_0[1][0]" );
498+ auto o20Pos = code.find (" _O_0[2][0]" );
499+ auto o30Pos = code.find (" _O_0[3][0]" );
500+
501+ EXPECT_TRUE (o00Pos != std::string::npos)
502+ << " expected constant subscripts in _O_0" ;
503+ EXPECT_TRUE (o10Pos != std::string::npos)
504+ << " expected constant subscripts in _O_0" ;
505+ EXPECT_TRUE (o20Pos != std::string::npos)
506+ << " expected constant subscripts in _O_0" ;
507+ EXPECT_TRUE (o30Pos != std::string::npos)
508+ << " expected constant subscripts in _O_0" ;
509+ }
488510};
489511
490512TEST_F (MatMulBias, RegisterPromotion) {
@@ -544,25 +566,7 @@ TEST_F(MatMulBias, RegistersAtRoot) {
544566
545567 // Expecting 4 elements because we map the loop i in O[i][j] to 8 threads
546568 // after tiling by 32.
547- auto oDeclPos = code.find (" float32 _O_0[4][1];" );
548- EXPECT_TRUE (oDeclPos != std::string::npos)
549- << " expected O to be promoted to registers" ;
550-
551- expectNoABCPromotion (code);
552-
553- auto o00Pos = code.find (" _O_0[0][0]" );
554- auto o10Pos = code.find (" _O_0[1][0]" );
555- auto o20Pos = code.find (" _O_0[2][0]" );
556- auto o30Pos = code.find (" _O_0[3][0]" );
557-
558- EXPECT_TRUE (o00Pos != std::string::npos)
559- << " expected constant subscripts in _O_0" ;
560- EXPECT_TRUE (o10Pos != std::string::npos)
561- << " expected constant subscripts in _O_0" ;
562- EXPECT_TRUE (o20Pos != std::string::npos)
563- << " expected constant subscripts in _O_0" ;
564- EXPECT_TRUE (o30Pos != std::string::npos)
565- << " expected constant subscripts in _O_0" ;
569+ expectFourOElementsPromoted (code);
566570}
567571
568572TEST_F (MatMulBias, RegistersAtRootNotEnoughUnroll) {
@@ -589,23 +593,24 @@ TEST_F(MatMulBias, RegistersBelowFirstBand) {
589593 using namespace polyhedral ::detail;
590594
591595 // Disable automatic promotion to registers because we are going to call it
592- // manually.
596+ // manually. Use a large unroll size to unroll all loops below the first
597+ // band and actually hit registers.
593598 auto mappingOptions = CudaMappingOptions::makeNaiveMappingOptions ()
599+ .unroll (512 )
594600 .useSharedMemory (false )
595601 .usePrivateMemory (false );
596602 auto mscop = prepare ({{" N" , 42 }, {" M" , 56 }, {" K" , 37 }}, mappingOptions);
597603
598- auto nodes = ScheduleTree::collectDFSPostorder (
604+ auto nodes = ScheduleTree::collectDFSPreorder (
599605 mscop->scop ().scheduleRoot (), ScheduleTreeType::Band);
600606 ASSERT_GT (nodes.size (), 0u );
601607 auto node = nodes[0 ];
602608 promoteToRegistersBelow (*mscop, node);
603609 auto code = emitCode (mscop);
604610
605- auto oDeclPos = code.find (" float32 _O_0[1][1];" );
606- EXPECT_TRUE (oDeclPos != std::string::npos)
607- << " expected O to be promoted to registers" ;
608- expectNoABCPromotion (code);
611+ // Expecting 4 elements because we map the loop i in O[i][j] to 8 threads
612+ // after tiling by 32.
613+ expectFourOElementsPromoted (code);
609614}
610615
611616class Strided : public TestMapper {
0 commit comments