@@ -91,13 +91,26 @@ isl::union_set makeFixRemainingZeroFilter(
 bool anyNonCoincidentMember(const detail::ScheduleTreeElemBand* band) {
   return band->nOuterCoincident() < band->nMember();
 }
+
+/*
+ * Return a reference to the mapping sizes
+ * for the mapping of type "MappingTypeId".
+ */
+template <typename MappingTypeId>
+const CudaDim& mappingSize(const MappedScop* mscop);
+template <>
+const CudaDim& mappingSize<mapping::BlockId>(const MappedScop* mscop) {
+  return mscop->numBlocks;
+}
+template <>
+const CudaDim& mappingSize<mapping::ThreadId>(const MappedScop* mscop) {
+  return mscop->numThreads;
+}
 } // namespace
 
 template <typename MappingTypeId>
-void MappedScop::mapRemaining(
-    detail::ScheduleTree* tree,
-    size_t nMapped,
-    size_t nToMap) {
+void MappedScop::mapRemaining(detail::ScheduleTree* tree, size_t nMapped) {
+  size_t nToMap = mappingSize<MappingTypeId>(this).view.size();
   if (nMapped >= nToMap) {
     return;
   }
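
The mappingSize helper added above is a declared-but-undefined function template with one explicit specialization per mapping type, which is what lets mapRemaining recover the grid or block extent from its template argument alone. Below is a minimal, self-contained sketch of that dispatch pattern; Dim, Blocks, Threads and Mapper are hypothetical stand-ins for CudaDim, mapping::BlockId, mapping::ThreadId and MappedScop.

#include <cstdio>
#include <vector>

struct Dim { std::vector<unsigned> view; };
struct Blocks {};   // stands in for mapping::BlockId
struct Threads {};  // stands in for mapping::ThreadId

struct Mapper {
  Dim numBlocks{{256, 128}};
  Dim numThreads{{32, 4, 2}};
};

// Declared but never defined: using an unspecialized MappingTypeId fails at
// link time, so only the two specializations below are usable.
template <typename MappingTypeId>
const Dim& mappingSize(const Mapper* m);

template <>
const Dim& mappingSize<Blocks>(const Mapper* m) { return m->numBlocks; }
template <>
const Dim& mappingSize<Threads>(const Mapper* m) { return m->numThreads; }

int main() {
  Mapper m;
  // Callers no longer pass nToMap explicitly; it follows from the type tag.
  std::printf("%zu block dims, %zu thread dims\n",
              mappingSize<Blocks>(&m).view.size(),
              mappingSize<Threads>(&m).view.size());
}

Leaving the primary template undefined turns any use with an unsupported mapping type into a link-time error rather than a silent fallback.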
@@ -140,52 +153,27 @@ void MappedScop::mapToBlocksAndScaleBand(
   for (size_t i = 0; i < nBlocksToMap; ++i) {
     band = map(band, i, mapping::BlockId::makeId(i));
   }
-  mapRemaining<mapping::BlockId>(band, nBlocksToMap, numBlocks.view.size());
+  mapRemaining<mapping::BlockId>(band, nBlocksToMap);
   bandScale(band, tileSizes);
 }
 
 /*
- * Given a filter node in the schedule tree of a mapped scop,
- * insert another filter underneath (if needed) that fixes
- * the thread identifiers in the range [begin, end) to zero.
+ * Given a node in the schedule tree of a mapped scop,
+ * insert a mapping filter underneath (if needed) that fixes
+ * the remaining thread identifiers starting at "begin" to zero.
  */
-void fixThreadsBelowFilter(
+void fixThreadsBelow(
     MappedScop& mscop,
-    detail::ScheduleTree* filterTree,
-    size_t begin,
-    size_t end) {
+    detail::ScheduleTree* tree,
+    size_t begin) {
+  size_t end = mscop.numThreads.view.size();
   if (begin == end) {
     return;
   }
 
-  std::unordered_set<mapping::ThreadId, mapping::ThreadId::Hash> ids;
-  for (size_t i = begin; i < end; ++i) {
-    ids.insert(mapping::ThreadId::makeId(i));
-  }
-  auto root = mscop.schedule();
-  auto domain = activeDomainPoints(root, filterTree);
-  auto mappingFilter = makeFixRemainingZeroFilter(domain, ids);
-  auto filter = filterTree->elemAs<detail::ScheduleTreeElemFilter>();
-  CHECK(filter) << "Not a filter: " << *filter;
-  // Active domain points will contain spaces for different statements
-  // When inserting below a leaf filter, this would break the tightening
-  // invariant that leaf mapping filters have a single space.
-  // So we intersect with the universe set of the filter to only keep the
-  // space for the legitimate statement.
-  mappingFilter = mappingFilter & filter->filter_.universe();
-  auto mapping = detail::ScheduleTree::makeMappingFilter(mappingFilter, ids);
-  insertNodeBelow(filterTree, std::move(mapping));
-
-  for (size_t i = begin; i < end; ++i) {
-    if (mapping::ThreadId::makeId(i) == mapping::ThreadId::x()) {
-      // Mapping happened below filterTree, so we need points active for its
-      // children. After insertion, filterTree is guaranteed to have at least
-      // one child.
-      mscop.threadIdxXScheduleDepthState.emplace_back(std::make_pair(
-          activeDomainPoints(mscop.schedule(), filterTree->child({0})),
-          filterTree->scheduleDepth(mscop.schedule())));
-    }
-  }
+  auto band = detail::ScheduleTree::makeEmptyBand(mscop.scop().scheduleRoot());
+  auto bandTree = insertNodeBelow(tree, std::move(band));
+  mscop.mapRemaining<mapping::ThreadId>(bandTree, begin);
 }
 
 bool MappedScop::detectReductions(detail::ScheduleTree* tree) {
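
The rewritten fixThreadsBelow no longer requires a filter node: it hangs an empty (zero-member) band below "tree" purely as an attachment point and defers to mapRemaining, which, per the first hunk, now derives the full thread count itself and presumably also builds the zero-fixing mapping filter and the threadIdx.x depth bookkeeping that the deleted lines used to duplicate. A small worked call, with hypothetical sizes:

// Assuming numThreads = {32, 4, 2} (three thread dimensions):
//   fixThreadsBelow(mscop, child, 1);
// computes end = 3, inserts an empty band below "child", and fixes the
// remaining identifiers makeId(1) and makeId(2) (threadIdx.y, threadIdx.z)
// to the value zero.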
@@ -239,7 +227,7 @@ bool MappedScop::detectReductions(detail::ScheduleTree* tree) {
   if (!inits.is_empty()) {
     orderBefore(scop_->scheduleRoot(), tree, inits);
   }
-  reductionBandUpdates_.emplace(tree, Reduction(updateIds, reductionDim));
+  reductionBandUpdates_.emplace(tree, Reduction(updateIds));
   return true;
 }
 
@@ -261,11 +249,9 @@ isl::multi_union_pw_aff MappedScop::reductionMapSchedule(
   // mapped to threads.
   auto reductionSchedule = reductionBand->mupa_;
   auto nMember = reductionBand->nMember();
-  auto reductionDim = reductionBandUpdates_.at(st).reductionDim;
-  auto nMappedThreads =
-      std::min(numThreads.view.size(), reductionBand->nOuterCoincident() + 1);
+  auto reductionDim = reductionBand->nOuterCoincident();
+  auto nMappedThreads = std::min(numThreads.view.size(), reductionDim + 1);
   CHECK_GE(nMember, reductionDim);
-  CHECK_GE(reductionDim + 1, nMappedThreads);
   reductionSchedule = reductionSchedule.drop_dims(
       isl::dim_type::set, reductionDim + 1, nMember - (reductionDim + 1));
   reductionSchedule = reductionSchedule.drop_dims(
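
A worked example of the member bookkeeping, with hypothetical sizes; the second drop_dims call is truncated by this hunk, so its arguments are inferred from the surviving first call:

// nMember = 4, reductionDim = nOuterCoincident() = 2, and three thread
// dimensions: nMappedThreads = min(3, 2 + 1) = 3.
// First call: drop_dims(set, 3, 4 - 3) removes the trailing member, keeping
// members {0, 1, 2}: the two coincident members plus the reduction member.
// Second call (inferred): drops the leading reductionDim + 1 - nMappedThreads
// = 0 members, so all three survivors are exactly the thread-mapped ones,
// with the reduction member innermost.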
@@ -332,45 +318,37 @@ size_t MappedScop::mapToThreads(detail::ScheduleTree* band) {
     return 0;
   }
 
-  size_t nMappedReductionThreads = 0;
-  if (reductionBandUpdates_.count(band) == 1) {
-    // A reduction is assumed to get mapped to threadIdx.x
-    CHECK(reductionBandUpdates_.at(band).separated);
-    auto reductionDim = reductionBandUpdates_.at(band).reductionDim;
-    threadIdxXScheduleDepthState.emplace_back(std::make_pair(
-        activeDomainPoints(schedule(), band),
-        band->scheduleDepth(schedule()) + reductionDim));
-    band = map(band, reductionDim, mapping::ThreadId::x());
-    nMappedReductionThreads = 1;
-  }
-
   // With current isl scheduler, if coincident dimensions exist in a band,
   // they are outermost.
   // If a band has more than 3 coincident dimensions,
   // then the innermost of those will be used.
-  auto nOuterCoincident = bandNode->nOuterCoincident();
-  if (nOuterCoincident < 1) {
-    return nMappedReductionThreads;
+  auto nCanMap = bandNode->nOuterCoincident();
+
+  auto isReduction = reductionBandUpdates_.count(band) == 1;
+  // If the band has a detected reduction, then the first member
+  // after the coincident members is the reduction member and
+  // this member has to be mapped as well.
+  // In particular, it will get mapped to threadIdx.x
+  if (isReduction) {
+    CHECK(reductionBandUpdates_.at(band).separated);
+    nCanMap++;
   }
 
-  auto nMappedThreads = std::min(
-      numThreads.view.size() - nMappedReductionThreads,
-      static_cast<size_t>(nOuterCoincident));
-
-  // Immediately return if mapping to one thread dimension only was requested
-  // and a reduction was already mapped. (Note that reduction is detected only
-  // if there are not enough outer coincident members, 0 in this case).
-  if (nMappedThreads == 0) {
-    return nMappedReductionThreads;
+  if (nCanMap < 1) {
+    return 0;
   }
-  CHECK_LE(nMappedThreads, 3 - nMappedReductionThreads)
-      << "mapping to too many threads";
+
+  auto nMappedThreads =
+      std::min(numThreads.view.size(), static_cast<size_t>(nCanMap));
+
+  CHECK_GT(nMappedThreads, 0) << "not mapping to threads";
+  CHECK_LE(nMappedThreads, 3) << "mapping to too many threads";
 
   // Map the coincident dimensions to threads starting from the innermost and
-  // from thread x unless it was already mapped to a reduction.
+  // from thread x.
   for (size_t i = 0; i < nMappedThreads; ++i) {
-    auto id = mapping::ThreadId::makeId(nMappedReductionThreads + i);
-    auto dim = nOuterCoincident - 1 - i;
+    auto id = mapping::ThreadId::makeId(i);
+    auto dim = nCanMap - 1 - i;
     if (id == mapping::ThreadId::x()) {
       threadIdxXScheduleDepthState.emplace_back(std::make_pair(
           activeDomainPoints(schedule(), band),
@@ -379,7 +357,11 @@ size_t MappedScop::mapToThreads(detail::ScheduleTree* band) {
     band = map(band, dim, id);
   }
 
-  return nMappedReductionThreads + nMappedThreads;
+  if (isReduction) {
+    splitOutReductionAndInsertSyncs(band, nCanMap - 1);
+  }
+
+  return nMappedThreads;
 }
 
 namespace {
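
The reworked mapToThreads folds the old reduction special case into a single count and maps innermost-first, so that with a detected reduction the reduction member (position nCanMap - 1) lands on threadIdx.x. A minimal sketch of just that computation, detached from the schedule tree; the function name and flat parameters are hypothetical:

#include <algorithm>
#include <cstddef>

// Returns how many band members get mapped to thread identifiers.
std::size_t threadsToMap(std::size_t nOuterCoincident, bool isReduction,
                         std::size_t nThreadDims /* numThreads.view.size() */) {
  std::size_t nCanMap = nOuterCoincident;
  if (isReduction) {
    ++nCanMap;  // the reduction member just after the coincident ones
  }
  if (nCanMap < 1) {
    return 0;
  }
  // At most 3 thread dimensions exist, so the result never exceeds 3.
  return std::min(nThreadDims, nCanMap);
}

The mapping loop then walks i = 0..nMappedThreads-1 with dim = nCanMap - 1 - i, so i = 0 pairs threadIdx.x with the innermost mapped member, which is exactly the reduction member when isReduction holds; that preserves the old invariant that reductions are mapped to threadIdx.x.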
@@ -419,21 +401,16 @@ bool hasOuterSequentialMember(
 // If any separation is needed for mapping reductions to full blocks,
 // then do so first.
 //
-// If "st" has multiple children, then make sure they are mapped
-// to the same number of thread identifiers by fixing those
-// that are originally mapped to fewer identifiers to value zero
-// for the remaining thread identifiers.
+// If "st" has multiple children and if any of those children
+// is mapped to threads, then make sure the other children
+// are also mapped to threads, by fixing the thread identifiers to value zero.
 // If, moreover, "st" is a sequence node and at least one of its
 // children is mapped to threads, then introduce synchronization
 // before and after children that are mapped to threads.
 // Also add synchronization between the last child and
 // the next iteration of the first child if there may be such
 // a next iteration that is not already covered by synchronization
 // on an outer node.
-// If any synchronization is introduced, then the mapping
-// to threads needs to be completed to all thread ids
-// because the synchronization needs to be introduced outside
-// any mapping to threads.
 size_t MappedScop::mapInnermostBandsToThreads(detail::ScheduleTree* st) {
   if (needReductionSeparation(st)) {
     st = separateReduction(st);
@@ -447,11 +424,10 @@ size_t MappedScop::mapInnermostBandsToThreads(detail::ScheduleTree* st) {
   auto n = nChildren > 0 ? *std::max_element(nInner.begin(), nInner.end()) : 0;
   if (nChildren > 1) {
     auto needSync = st->elemAs<detail::ScheduleTreeElemSequence>() && n > 0;
-    if (needSync) {
-      n = numThreads.view.size();
-    }
-    for (size_t i = 0; i < nChildren; ++i) {
-      fixThreadsBelowFilter(*this, children[i], nInner[i], n);
+    if (n > 0) {
+      for (size_t i = 0; i < nChildren; ++i) {
+        fixThreadsBelow(*this, children[i], nInner[i]);
+      }
     }
     if (needSync) {
       auto outer = hasOuterSequentialMember(scop_->scheduleRoot(), st);
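
The old code widened n to numThreads.view.size() before padding the children; that special case is now unnecessary because fixThreadsBelow derives the end of the range from mscop.numThreads itself. A small trace with hypothetical sizes:

// numThreads = {32, 4}; three children mapped to nInner = {2, 1, 0} ids.
// Old: needSync forced n = numThreads.view.size() = 2 before padding.
// New: n = max(nInner) = 2 > 0 triggers the loop, and fixThreadsBelow pads
// each child from nInner[i] up to 2 on its own, so any synchronization
// inserted afterwards still sits outside all mappings to threads.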
@@ -474,7 +450,7 @@ size_t MappedScop::mapInnermostBandsToThreads(detail::ScheduleTree* st) {
   // because we cannot map parent bands anyway.
   auto nMapped = mapToThreads(st);
   if (nMapped > 0) {
-    mapRemaining<mapping::ThreadId>(st, nMapped, numThreads.view.size());
+    mapRemaining<mapping::ThreadId>(st, nMapped);
     markUnroll(scop_->scheduleRoot(), st, unroll);
     return numThreads.view.size();
   }
@@ -594,19 +570,16 @@ std::tuple<std::string, tc::Grid, tc::Block> MappedScop::codegen(
       mappedScopForCodegen->numThreads);
 }
 
-// Split out reduction loops into separate bands and insert reduction
-// synchronizations outside those bands.
-void MappedScop::splitOutReductionsAndInsertSyncs() {
+// Split out reduction member at position "dim" in "band" and
+// insert reduction synchronizations outside this split off band.
+void MappedScop::splitOutReductionAndInsertSyncs(
+    detail::ScheduleTree* band,
+    int dim) {
   using namespace polyhedral::detail;
 
-  for (auto bandUpdate : reductionBandUpdates_) {
-    auto tree = bandSplitOut(
-        scop_->scheduleRoot(),
-        const_cast<ScheduleTree*>(bandUpdate.first),
-        bandUpdate.second.reductionDim);
-    for (auto updateId : bandUpdate.second.ids) {
-      scop_->insertReductionSync1D(tree, updateId);
-    }
+  auto tree = bandSplitOut(scop_->scheduleRoot(), band, dim);
+  for (auto updateId : reductionBandUpdates_.at(band).ids) {
+    scop_->insertReductionSync1D(tree, updateId);
   }
 }
 
@@ -660,8 +633,7 @@ std::unique_ptr<MappedScop> MappedScop::makeWithOuterBlockInnerThreadStrategy(
     auto child = outerBand->child({0});
     size_t numMappedInnerThreads =
         mappedScop->mapInnermostBandsToThreads(child);
-    mappedScop->mapRemaining<mapping::ThreadId>(
-        child, numMappedInnerThreads, mappedScop->numThreads.view.size());
+    mappedScop->mapRemaining<mapping::ThreadId>(child, numMappedInnerThreads);
     LOG_IF(INFO, FLAGS_debug_tc_mapper)
         << "After mapping to threads:" << std::endl
         << *mappedScop->schedule();
@@ -673,13 +645,7 @@ std::unique_ptr<MappedScop> MappedScop::makeWithOuterBlockInnerThreadStrategy(
   LOG_IF(INFO, FLAGS_debug_tc_mapper) << "After mapping to blocks:" << std::endl
                                       << *mappedScop->schedule();
 
-  // 7. Insert reduction synchronizations if necessary.
-  mappedScop->splitOutReductionsAndInsertSyncs();
-  LOG_IF(INFO, FLAGS_debug_tc_mapper)
-      << "After inserting reduction synchronization:" << std::endl
-      << *mappedScop->schedule();
-
-  // 8. Promote to shared memory below the loops mapped to blocks.
+  // 7. Promote to shared memory below the loops mapped to blocks.
   // This may split the outer band, so find the new outer band after promotion.
   if (cudaOptions.proto().use_shared_memory()) {
     size_t sharedMemorySize = cudaOptions.proto().has_max_shared_memory()
@@ -726,13 +692,13 @@ std::unique_ptr<MappedScop> MappedScop::makeWithOuterBlockInnerThreadStrategy(
     }
   }
 
-  // 9. Promote to registers below the loops mapped to threads.
+  // 8. Promote to registers below the loops mapped to threads.
   if (cudaOptions.proto().use_private_memory()) {
     promoteToRegistersBelowThreads(
         mappedScop->scop(), mappedScop->threadIdxXScheduleDepthState, -1ull);
   }
 
-  // 10. Insert mapping context
+  // 9. Insert mapping context
   mappedScop->insertMappingContext();
   LOG_IF(INFO, FLAGS_debug_tc_mapper)
       << "After outerBlockInnerThread strategy:" << std::endl