2727#include < algorithm>
2828#include < numeric>
2929#include < sstream>
30+ #include < type_traits>
3031
3132namespace tc {
3233namespace polyhedral {
@@ -128,6 +129,21 @@ void mapCopiesToThreads(MappedScop& mscop, bool unroll) {
128129 }
129130}
130131
132+ /*
133+ * Starting from the root, find all thread specific markers. Use
134+ * DFSPreorder to make sure order is specified and consistent for tests.
135+ */
136+ template <typename T>
137+ std::vector<T> findThreadSpecificMarkers (T root) {
138+ using namespace tc ::polyhedral::detail;
139+ static_assert (
140+ std::is_convertible<T, const ScheduleTree*>::value,
141+ " expecting ScheduleTree" );
142+
143+ return ScheduleTree::collectDFSPreorder (
144+ root, ScheduleTreeType::ThreadSpecificMarker);
145+ }
146+
131147/*
132148 * Transform schedule bands into a union_map.
133149 * Takes all partial schedules at leaves as MUPAs (without accounting for
@@ -555,51 +571,28 @@ void promoteGreedilyAtDepth(
555571 mapCopiesToThreads (mscop, unrollCopies);
556572}
557573
558- // Assuming the mapping to threads happens in inverse order, i.e. the innermost
559- // loop is mapped to thread x, promote below that depth.
560- void promoteToRegistersBelowThreads (
561- Scop& scop,
562- const ThreadIdxXScheduleDepthState& threadIdxXScheduleDepthState,
563- size_t nRegisters) {
574+ // Promote at the positions of the thread specific markers.
575+ void promoteToRegistersBelowThreads (Scop& scop, size_t nRegisters) {
564576 using namespace tc ::polyhedral::detail;
565577
566578 auto root = scop.scheduleRoot ();
567579
568580 auto fullSched = fullSchedule (root);
569- for (const auto & kvp : threadIdxXScheduleDepthState) {
570- auto depth = kvp.second + 1 ;
571- auto subdomain = kvp.first ;
572-
573- // Collect all bands where a member is located at the given depth.
574- auto bands = bandsContainingScheduleDepth (root, depth);
575- // We may have no band members mapped to thread x in case when we
576- // force-mapped everything to one thread.
577- if (bands.size () == 0 ) {
578- continue ;
579- }
580-
581- // Keep only those bands for which this depth was recorded.
582- std::function<bool (ScheduleTree*)> keepActive =
583- [root, subdomain](const ScheduleTree* tree) {
584- isl::union_set active = activeDomainPoints (root, tree);
585- return !active.intersect (subdomain).is_empty ();
586- };
587- bands = functional::Filter (keepActive, bands);
588-
589- // Make sure the band ends at thread x depth so we can promote below it.
590- bands = bandsSplitAfterDepth (bands, root, depth);
581+ {
582+ auto markers = findThreadSpecificMarkers (root);
591583
592- for (auto band : bands ) {
584+ for (auto marker : markers ) {
593585 // Find out how many threads are actually mapped. Active domain points
594586 // will involve all mapping parameters when we take them below the
595587 // mapping. Skip mapping parameters obviously mapped to 0, because they
596588 // do not correspond to band members that should be fixed to obtain
597589 // per-thread-group access relations.
598- auto points = activeDomainPoints (root, band );
599- auto partialSched = partialSchedule (root, band );
590+ auto points = activeDomainPoints (root, marker );
591+ auto partialSched = prefixSchedule (root, marker );
600592 // Pure affine schedule without (mapping) filters.
601- auto partialSchedMupa = partialScheduleMupa (root, band );
593+ auto partialSchedMupa = prefixScheduleMupa (root, marker );
602594
595+ auto depth = marker->scheduleDepth (root);
603596 size_t nMappedThreads = 0 ;
604597 for (unsigned j = 0 ; j < points.dim (isl::dim_type::param); ++j) {
605598 auto id = points.get_space ().get_dim_id (isl::dim_type::param, j);
@@ -616,7 +609,7 @@ void promoteToRegistersBelowThreads(
616609 }
617610 }
618611
619- auto groupMap = TensorReferenceGroup::accessedBySubtree (band , scop);
612+ auto groupMap = TensorReferenceGroup::accessedBySubtree (marker , scop);
620613 for (auto & tensorGroups : groupMap) {
621614 auto tensorId = tensorGroups.first ;
622615
@@ -642,7 +635,7 @@ void promoteToRegistersBelowThreads(
642635 Scop::PromotedDecl::Kind::Register,
643636 tensorId,
644637 std::move (group),
645- band ,
638+ marker ,
646639 partialSched);
647640 }
648641 }
0 commit comments