@@ -622,6 +622,9 @@ void promoteToRegistersBelowThreads(
622622 // do not correspond to band members that should be fixed to obtain
623623 // per-thread-group access relations.
624624 auto points = activeDomainPoints (root, band);
625+ auto partialSched = partialSchedule (root, band);
626+ auto activeStmts = activeStatements (root, band);
627+
625628 size_t nMappedThreads = 0 ;
626629 for (int j = 0 ; j < points.dim (isl::dim_type::param); ++j) {
627630 auto id = points.get_space ().get_dim_id (isl::dim_type::param, j);
@@ -639,12 +642,12 @@ void promoteToRegistersBelowThreads(
639642 }
640643
641644 auto groupMap = TensorReferenceGroup::accessedBySubtree (band, scop);
642- for (const auto & tensorGroups : groupMap) {
645+ for (auto & tensorGroups : groupMap) {
643646 auto tensorId = tensorGroups.first ;
644647
645648 // TODO: sorting of groups and counting the number of promoted elements
646649
647- for (const auto & group : tensorGroups.second ) {
650+ for (auto & group : tensorGroups.second ) {
648651 auto sizes = group->approximationSizes ();
649652 // No point in promoting a scalar that will go to a register anyway.
650653 if (sizes.size () == 0 ) {
@@ -664,6 +667,13 @@ void promoteToRegistersBelowThreads(
664667 // TODO: if something is already in shared but is reused within one
665668 // thread only, there is no point in keeping it in shared _if_ it
666669 // gets promoted into a register.
670+ scop.promoteGroup (
671+ Scop::PromotedDecl::Kind::Register,
672+ tensorId,
673+ std::move (group),
674+ band,
675+ activeStmts,
676+ partialSched);
667677 }
668678 }
669679 }
0 commit comments