@@ -108,22 +108,69 @@ const CudaDim& mappingSize<mapping::ThreadId>(const MappedScop* mscop) {
108108}
109109} // namespace
110110
111+ // Map the elements in "list" to successive blocks or thread identifiers,
112+ // with the first element mapped to identifier X. The extents are obtained
113+ // from the initial elements of numBlocks or numThreads. The identifiers
114+ // must not be present in the space of the partial schedules in "list" and
115+ // extents must be non-zero. The mapping corresponds to inserting a filter
116+ // node with condition 'list % extent = ids'.
117+ // The mapping is inserted above "tree".
118+ //
119+ // Return a pointer to the updated node (below the inserted filter)
120+ // for call chaining purposes.
111121template <typename MappingTypeId>
112- void MappedScop::mapRemaining (detail::ScheduleTree* tree, size_t nMapped) {
113- size_t nToMap = mappingSize<MappingTypeId>(this ).view .size ();
114- if (nMapped >= nToMap) {
115- return ;
122+ detail::ScheduleTree* MappedScop::map (
123+ detail::ScheduleTree* tree,
124+ isl::union_pw_aff_list list) {
125+ size_t nToMap = list.n ();
126+ const auto & extent = mappingSize<MappingTypeId>(this ).view ;
127+ CHECK_LE (nToMap, extent.size ()) << " dimension overflow" ;
128+
129+ auto root = scop_->scheduleRoot ();
130+ auto domain = activeDomainPoints (root, tree).universe ();
131+ auto filter = domain;
132+
133+ std::unordered_set<MappingTypeId, typename MappingTypeId::Hash> idSet;
134+ for (size_t i = 0 ; i < nToMap; ++i) {
135+ auto id = MappingTypeId::makeId (i);
136+ auto upa = list.get (i);
137+ // Introduce the "mapping" parameter after checking it is not already
138+ // present in the schedule space.
139+ CHECK (not upa.involves_param (id));
140+ CHECK_NE (extent[i], 0u ) << " NYI: mapping to 0" ;
141+
142+ // Create mapping filter by equating the newly introduced
143+ // parameter ids[i] to the "i"-th affine function modulo its extent.
144+ upa = upa.mod_val (isl::val (tree->ctx_ , extent[i]));
145+ upa = upa.sub (isl::union_pw_aff::param_on_domain (domain, id));
146+ filter = filter.intersect (upa.zero_union_set ());
147+
148+ idSet.emplace (id);
116149 }
117150
118- std::unordered_set<MappingTypeId, typename MappingTypeId::Hash> ids;
119- for (size_t i = nMapped; i < nToMap; ++i) {
120- ids.insert (MappingTypeId::makeId (i));
151+ std::unordered_set<MappingTypeId, typename MappingTypeId::Hash> unmapped;
152+ for (size_t i = nToMap; i < extent.size (); ++i) {
153+ auto id = MappingTypeId::makeId (i);
154+ unmapped.emplace (id);
155+ idSet.emplace (id);
121156 }
122- auto root = scop_->scheduleRoot ();
123- auto domain = activeDomainPoints (root, tree);
124- auto filter = makeFixRemainingZeroFilter (domain, ids);
125- auto mapping = detail::ScheduleTree::makeMappingFilter (filter, ids);
126- insertNodeAbove (root, tree, std::move (mapping));
157+ filter = filter.intersect (makeFixRemainingZeroFilter (domain, unmapped));
158+
159+ auto mapping = detail::ScheduleTree::makeMappingFilter (filter, idSet);
160+ tree = insertNodeAbove (root, tree, std::move (mapping))->child ({0 });
161+
162+ return tree;
163+ }
164+
165+ detail::ScheduleTree* MappedScop::mapBlocksForward (
166+ detail::ScheduleTree* band,
167+ size_t nToMap) {
168+ auto bandNode = band->elemAs <detail::ScheduleTreeElemBand>();
169+ CHECK (bandNode) << " expected a band, got " << *band;
170+
171+ auto list = bandNode->mupa_ .get_union_pw_aff_list ();
172+ list = list.drop (nToMap, list.n () - nToMap);
173+ return map<mapping::BlockId>(band, list);
127174}
128175
129176// Uses as many blockSizes elements as outer coincident dimensions in the
@@ -142,10 +189,7 @@ void MappedScop::mapToBlocksAndScaleBand(
142189 // and no more than block dimensions to be mapped
143190 nBlocksToMap = std::min (nBlocksToMap, numBlocks.view .size ());
144191
145- for (size_t i = 0 ; i < nBlocksToMap; ++i) {
146- band = map (band, i, mapping::BlockId::makeId (i));
147- }
148- mapRemaining<mapping::BlockId>(band, nBlocksToMap);
192+ mapBlocksForward (band, nBlocksToMap);
149193 bandScale (band, tileSizes);
150194}
151195
@@ -166,10 +210,7 @@ void fixThreadsBelow(
166210
167211 auto band = detail::ScheduleTree::makeEmptyBand (mscop.scop ().scheduleRoot ());
168212 auto bandTree = insertNodeBelow (tree, std::move (band));
169- auto ctx = tree->ctx_ ;
170- insertNodeBelow (
171- bandTree, detail::ScheduleTree::makeThreadSpecificMarker (ctx));
172- mscop.mapRemaining <mapping::ThreadId>(bandTree, begin);
213+ mscop.mapThreadsBackward (bandTree);
173214}
174215
175216bool MappedScop::detectReductions (detail::ScheduleTree* tree) {
@@ -305,6 +346,22 @@ detail::ScheduleTree* MappedScop::separateReduction(detail::ScheduleTree* st) {
305346 return st->ancestor (root, 2 );
306347}
307348
349+ detail::ScheduleTree* MappedScop::mapThreadsBackward (
350+ detail::ScheduleTree* band) {
351+ auto bandNode = band->elemAs <detail::ScheduleTreeElemBand>();
352+ CHECK (bandNode);
353+ auto nMember = bandNode->nMember ();
354+ auto nToMap = std::min (nMember, numThreads.view .size ());
355+ CHECK_LE (nToMap, 3 ) << " mapping to too many threads" ;
356+
357+ auto ctx = band->ctx_ ;
358+ insertNodeBelow (band, detail::ScheduleTree::makeThreadSpecificMarker (ctx));
359+
360+ auto list = bandNode->mupa_ .get_union_pw_aff_list ().reverse ();
361+ list = list.drop (nToMap, list.n () - nToMap);
362+ return map<mapping::ThreadId>(band, list);
363+ }
364+
308365size_t MappedScop::mapToThreads (detail::ScheduleTree* band) {
309366 using namespace tc ::polyhedral::detail;
310367
@@ -355,20 +412,9 @@ size_t MappedScop::mapToThreads(detail::ScheduleTree* band) {
355412 bandSplit (scop_->scheduleRoot (), band, nMappedThreads);
356413 }
357414
358- auto ctx = band->ctx_ ;
359- insertNodeBelow (band, detail::ScheduleTree::makeThreadSpecificMarker (ctx));
360-
361415 CHECK_GT (nMappedThreads, 0 ) << " not mapping to threads" ;
362- CHECK_LE (nMappedThreads, 3 ) << " mapping to too many threads" ;
363416
364- // Map the coincident dimensions to threads starting from the innermost and
365- // from thread x.
366- for (size_t i = 0 ; i < nMappedThreads; ++i) {
367- auto id = mapping::ThreadId::makeId (i);
368- auto dim = nMappedThreads - 1 - i;
369- band = map (band, dim, id);
370- }
371- mapRemaining<mapping::ThreadId>(band, nMappedThreads);
417+ mapThreadsBackward (band);
372418
373419 if (isReduction) {
374420 splitOutReductionAndInsertSyncs (band, nMappedThreads - 1 );
0 commit comments