@@ -197,6 +197,11 @@ bool MappedScop::detectReductions(detail::ScheduleTree* tree) {
197197 return found;
198198 }
199199
200+ // Only reductions that appear in permutable bands are mapped to threads.
201+ if (!band->permutable_ ) {
202+ return false ;
203+ }
204+
200205 // For now, only support reductions with a sufficient number
201206 // of coincident outer band members for the remaining thread identifiers.
202207 auto nCoincident = band->nOuterCoincident ();
@@ -225,65 +230,50 @@ bool MappedScop::detectReductions(detail::ScheduleTree* tree) {
225230 if (!isReductionMember (member, updates, scop ())) {
226231 return false ;
227232 }
228- auto reductionTree = bandSplitOut (scop_->scheduleRoot (), tree, reductionDim);
229233 // Order the init statements (if any) before the update statements
230234 // to ensure the band from which the reduction band has been split off
231235 // only contains update statements.
232236 // Note that this relies on the outer members being coincident.
233237 if (!inits.is_empty ()) {
234238 orderBefore (scop_->scheduleRoot (), tree, inits);
235239 }
236- reductionFromParent_.emplace (tree, reductionTree);
237- reductionBandUpdates_.emplace (reductionTree, updateIds);
240+ reductionBandUpdates_.emplace (tree, Reduction (updateIds, reductionDim));
238241 return true ;
239242}
240243
// Review note: this is a rendered diff; "-" lines are removed and "+" lines
// are added by the change.
// After the change, needReductionSeparation looks "st" up directly in
// reductionBandUpdates_ (the reductionFromParent_ indirection is removed).
// It returns false when "st" has no entry there; otherwise it returns true
// only while that entry's "separated" flag is still unset.
241244bool MappedScop::needReductionSeparation (const detail::ScheduleTree* st) {
242- // It is the parent band of the reduction band that needs to be separated.
243- if (reductionFromParent_.count (st) != 1 ) {
245+ if (reductionBandUpdates_.count (st) != 1 ) {
244246 return false ;
245247 }
246- st = reductionFromParent_.at (st);
247- CHECK (reductionBandUpdates_.count (st) == 1 );
248248 // Separation is only needed if this band has not been separated yet.
249249 return !reductionBandUpdates_.at (st).separated ;
250250}
251251
// Review note: this is a rendered diff; "-" lines are removed and "+" lines
// are added by the change.
// After the change, reductionMapSchedule starts from the partial schedule
// (mupa_) of the band "st", which must have an entry in
// reductionBandUpdates_, and returns only the nMappedThreads band members
// that end at the recorded reduction dimension.
252252isl::multi_union_pw_aff MappedScop::reductionMapSchedule (
253253 const detail::ScheduleTree* st) {
254- CHECK (reductionFromParent_.count (st) == 1 );
255- auto parent = st;
256- st = reductionFromParent_.at (st);
257254 CHECK (reductionBandUpdates_.count (st) == 1 );
258-
259255 auto reductionBand = st->elemAs <detail::ScheduleTreeElemBand>();
260256 CHECK (reductionBand);
261- // Start with the schedule of the reduction band (in last position).
262- auto reductionSchedule = reductionBand->mupa_ ;
263257
264- // Total size of returned schedule needs to be equal
265- // to the number of thread identifiers.
266- if (numThreads.view .size () > 1 ) {
267- CHECK (parent != st);
268- }
269- // Prepend last members of parent band (if any).
270- if (parent != st) {
271- auto parentBand = parent->elemAs <detail::ScheduleTreeElemBand>();
272- CHECK (parentBand);
273- auto parentSchedule = parentBand->mupa_ ;
274- auto nMember = parentBand->nMember ();
275- CHECK_GE (nMember, numThreads.view .size () - 1 );
276- parentSchedule = parentSchedule.drop_dims (
277- isl::dim_type::set, 0 , nMember - (numThreads.view .size () - 1 ));
278- reductionSchedule = parentSchedule.flat_range_product (reductionSchedule);
279- }
258+ // Drop band members following the reduction dimension and preceding those
259+ // mapped to threads.
260+ auto reductionSchedule = reductionBand->mupa_ ;
261+ auto nMember = reductionBand->nMember ();
262+ auto reductionDim = reductionBandUpdates_.at (st).reductionDim ;
263+ auto nMappedThreads =
264+ std::min (numThreads.view .size (), reductionBand->nOuterCoincident () + 1 );
// Review fix: reductionDim is used as a member index, and the first
// drop_dims call below removes nMember - (reductionDim + 1) members, so
// nMember must be strictly greater than reductionDim. CHECK_GE accepted
// nMember == reductionDim, which would make that count negative (or wrap
// around if the type is unsigned).
265+ CHECK_GT (nMember, reductionDim);
// Guards the second drop_dims count, reductionDim - nMappedThreads + 1.
266+ CHECK_GE (reductionDim + 1 , nMappedThreads);
// First drop: every member after the reduction dimension.
267+ reductionSchedule = reductionSchedule.drop_dims (
268+ isl::dim_type::set, reductionDim + 1 , nMember - (reductionDim + 1 ));
// Second drop: leading members, so that exactly nMappedThreads members
// remain, with the reduction member in last position.
269+ reductionSchedule = reductionSchedule.drop_dims (
270+ isl::dim_type::set, 0 , reductionDim - nMappedThreads + 1 );
280271
281272 return reductionSchedule;
282273}
283274
284275detail::ScheduleTree* MappedScop::separateReduction (detail::ScheduleTree* st) {
285- CHECK (reductionFromParent_.count (st) == 1 );
286- auto reduction = reductionFromParent_.at (st);
276+ auto reduction = st;
287277 // This function either separates full blocks (if needed) or
288278 // disables the reduction handling.
289279 reductionBandUpdates_.at (reduction).separated = true ;
@@ -331,59 +321,54 @@ detail::ScheduleTree* MappedScop::separateReduction(detail::ScheduleTree* st) {
331321 return st->ancestor (root, 2 );
332322}
333323
334- size_t MappedScop::mapToThreads (detail::ScheduleTree* band, size_t nInner ) {
324+ size_t MappedScop::mapToThreads (detail::ScheduleTree* band) {
335325 using namespace tc ::polyhedral::detail;
336326
337- if (nInner >= numThreads.view .size ()) {
338- return nInner;
327+ auto bandNode = band->elemAs <ScheduleTreeElemBand>();
328+ // Cannot map non-permutable bands.
329+ if (!bandNode->permutable_ ) {
330+ return 0 ;
339331 }
332+
333+ int nMappedReductionThreads = 0 ;
340334 if (reductionBandUpdates_.count (band) == 1 ) {
341335 // A reduction is assumed to get mapped to threadIdx.x
342- if (nInner != 0 ) {
343- reductionBandUpdates_.erase (band);
344- return nInner;
345- }
346336 CHECK (reductionBandUpdates_.at (band).separated );
347337 threadIdxXScheduleDepthState.emplace_back (std::make_pair (
348338 activeDomainPoints (schedule (), band),
349339 band->scheduleDepth (schedule ()) + 0 ));
350- band = map (band, 0 , mapping::ThreadId::x ());
351- markUnroll (scop_->scheduleRoot (), band, unroll);
352- return 1 ;
353- }
354- auto bandNode = band->elemAs <ScheduleTreeElemBand>();
355- // If any inner node was mapped to threads and
356- // the current node has a non-coincident member,
357- // then synchronization needs to be introduced.
358- // This also implies that the mapping needs to be completed first.
359- if (anyNonCoincidentMember (bandNode) && nInner > 0 ) {
360- // Since some thread identifiers were mapped already (nInner > 0),
361- // the band should have descendants. Double check.
362- CHECK_EQ (band->numChildren (), 1 );
363- mapRemaining<mapping::ThreadId>(
364- band->child ({0 }), nInner, numThreads.view .size ());
365- scop_->insertSyncAfter (band->child ({0 }));
366- return numThreads.view .size ();
340+ auto reductionDim = reductionBandUpdates_.at (band).reductionDim ;
341+ band = map (band, reductionDim, mapping::ThreadId::x ());
342+ nMappedReductionThreads = 1 ;
367343 }
344+
368345 // With current isl scheduler, if coincident dimensions exist in a band,
369346 // they are outermost.
370- // If a band has more than 3 coincident dimensions, this will choose
371- // outermost, but we may also want innermost .
347+ // If a band has more than 3 coincident dimensions,
348+ // then the innermost of those will be used .
372349 auto nOuterCoincident = bandNode->nOuterCoincident ();
373- if (!bandNode-> permutable_ || nOuterCoincident < 1 ) {
374- return nInner ;
350+ if (nOuterCoincident < 1 ) {
351+ return nMappedReductionThreads ;
375352 }
376353
377354 auto nMappedThreads = std::min (
378- numThreads.view .size () - nInner, static_cast <size_t >(nOuterCoincident));
379- CHECK_GT (nMappedThreads, 0 ) << " not mapping to threads" ;
380- CHECK_LE (nMappedThreads, 3 - nInner) << " mapping to too many threads" ;
355+ numThreads.view .size () - nMappedReductionThreads,
356+ static_cast <size_t >(nOuterCoincident));
357+
358+ // Immediately return if mapping to one thread dimension only was requested
359+ // and a reduction was already mapped. (Note that reduction is detected only
360+ // if there are not enough outer coincident members, 0 in this case).
361+ if (nMappedThreads == 0 ) {
362+ return nMappedReductionThreads;
363+ }
364+ CHECK_LE (nMappedThreads, 3 - nMappedReductionThreads)
365+ << " mapping to too many threads" ;
381366
382367 // Map the coincident dimensions to threads starting from the innermost and
383- // from thread x.
384- for (int i = 0 , dim = nOuterCoincident - 1 ; i < nMappedThreads && dim >= 0 ;
385- ++i, --dim) {
386- auto id = mapping::ThreadId::makeId (nInner + i) ;
368+ // from thread x unless it was already mapped to a reduction .
369+ for (int i = 0 ; i < nMappedThreads; ++i) {
370+ auto id = mapping::ThreadId::makeId (nMappedReductionThreads + i);
371+ auto dim = nOuterCoincident - 1 - i ;
387372 if (id == mapping::ThreadId::x ()) {
388373 threadIdxXScheduleDepthState.emplace_back (std::make_pair (
389374 activeDomainPoints (schedule (), band),
@@ -392,11 +377,7 @@ size_t MappedScop::mapToThreads(detail::ScheduleTree* band, size_t nInner) {
392377 band = map (band, dim, id);
393378 }
394379
395- if (nInner == 0 ) {
396- markUnroll (scop_->scheduleRoot (), band, unroll);
397- }
398-
399- return nInner + nMappedThreads;
380+ return nMappedReductionThreads + nMappedThreads;
400381}
401382
402383namespace {
@@ -426,8 +407,12 @@ bool hasOuterSequentialMember(
426407}
427408} // namespace
428409
429- // Maps bands to threads in DFS postorder, keeping track of
430- // the (maximal) number of threads already mapped by descendants.
410+ // Maps bands to threads in DFS postorder.
411+ // Mapping is only allowed if descendants are not already mapped to threads.
412+ // Mapping nested bands to threads is invalid because members of those bands
413+ // are not necessarily permutable, and there is no guaranteed nesting between
414+ // thread dimensions (e.g., there is no guarantee that all threads with
415+ // threadIdx.y=0 will be executed before any thread with threadIdx.y=1).
431416//
432417// If any separation is needed for mapping reductions to full blocks,
433418// then do so first.
@@ -479,8 +464,28 @@ size_t MappedScop::mapInnermostBandsToThreads(detail::ScheduleTree* st) {
479464 }
480465 }
481466
482- if (st->elemAs <detail::ScheduleTreeElemBand>()) {
483- n = mapToThreads (st, n);
467+ if (auto band = st->elemAs <detail::ScheduleTreeElemBand>()) {
468+ if (n == 0 ) {
469+ // If children were not mapped to threads, the current band can be mapped.
470+ // First, map the coincident and reduction dimensions to threads.
471+ // Then, if some threads were mapped, fix unused thread dimensions to 0
472+ // because we cannot map parent bands anyway.
473+ auto nMapped = mapToThreads (st);
474+ if (nMapped > 0 ) {
475+ mapRemaining<mapping::ThreadId>(st, nMapped, numThreads.view .size ());
476+ markUnroll (scop_->scheduleRoot (), st, unroll);
477+ return numThreads.view .size ();
478+ }
479+ } else if (anyNonCoincidentMember (band)) {
480+ // If children were mapped to threads, and this band has a non-coincident
481+ // member, insert a synchronization after its last child.
482+ // The node must have children if some of them were mapped to threads,
483+ // double-check. Note that a band node has at most one child.
484+ CHECK_EQ (st->numChildren (), 1 );
485+ // The mapping should always be complete, double-check.
486+ CHECK_EQ (n, numThreads.view .size ());
487+ scop_->insertSyncAfter (st->child ({0 }));
488+ }
484489 }
485490
486491 return n;
@@ -587,6 +592,22 @@ std::tuple<std::string, tc::Grid, tc::Block> MappedScop::codegen(
587592 mappedScopForCodegen->numThreads );
588593}
589594
595+ // Split out reduction loops into separate bands and insert reduction
596+ // synchronizations outside those bands.
// Review note: for every entry of reductionBandUpdates_, the band tree used
// as key is passed to bandSplitOut together with the entry's reductionDim,
// and insertReductionSync1D is then called on the returned tree once per
// id stored in the entry's "ids".
597+ void MappedScop::splitOutReductionsAndInsertSyncs () {
598+ using namespace polyhedral ::detail;
599+
// Review fix: bind by const reference; iterating by value copied each map
// entry (including its "ids") on every iteration.
600+ for (const auto& bandUpdate : reductionBandUpdates_) {
601+ auto tree = bandSplitOut (
602+ scop_->scheduleRoot (),
603+ const_cast <ScheduleTree*>(bandUpdate.first ),
604+ bandUpdate.second .reductionDim );
// Review fix: likewise, avoid copying each id just to pass it on.
605+ for (const auto& updateId : bandUpdate.second .ids ) {
606+ scop_->insertReductionSync1D (tree, updateId);
607+ }
608+ }
609+ }
610+
590611std::unique_ptr<MappedScop> MappedScop::makeWithOuterBlockInnerThreadStrategy (
591612 std::unique_ptr<Scop>&& scopUPtr,
592613 const CudaMappingOptions& cudaOptions) {
@@ -650,7 +671,13 @@ std::unique_ptr<MappedScop> MappedScop::makeWithOuterBlockInnerThreadStrategy(
650671 LOG_IF (INFO, FLAGS_debug_tc_mapper) << " After mapping to blocks:" << std::endl
651672 << *mappedScop->schedule ();
652673
653- // 7. Promote to shared memory below the loops mapped to blocks.
674+ // 7. Insert reduction synchronizations if necessary.
675+ mappedScop->splitOutReductionsAndInsertSyncs ();
676+ LOG_IF (INFO, FLAGS_debug_tc_mapper)
677+ << " After inserting reduction synchronization:" << std::endl
678+ << *mappedScop->schedule ();
679+
680+ // 8. Promote to shared memory below the loops mapped to blocks.
654681 // This may split the outer band, so find the new outer band after promotion.
655682 if (cudaOptions.proto ().use_shared_memory ()) {
656683 size_t sharedMemorySize = cudaOptions.proto ().has_max_shared_memory ()
@@ -697,24 +724,16 @@ std::unique_ptr<MappedScop> MappedScop::makeWithOuterBlockInnerThreadStrategy(
697724 }
698725 }
699726
700- // 8 . Promote to registers below the loops mapped to threads.
727+ // 9 . Promote to registers below the loops mapped to threads.
701728 if (cudaOptions.proto ().use_private_memory ()) {
702729 promoteToRegistersBelowThreads (
703730 mappedScop->scop (), mappedScop->threadIdxXScheduleDepthState , -1ull );
704731 }
705732
706- // 9 . Insert mapping context
733+ // 10 . Insert mapping context
707734 mappedScop->insertMappingContext ();
708-
709- // 10. Optionally insert reduction synchronizations
710- for (auto bandUpdate : mappedScop->reductionBandUpdates_ ) {
711- for (auto updateId : bandUpdate.second .ids ) {
712- scop->insertReductionSync1D (
713- const_cast <ScheduleTree*>(bandUpdate.first ), updateId);
714- }
715- }
716735 LOG_IF (INFO, FLAGS_debug_tc_mapper)
717- << " After inserting reduction synchronization :" << std::endl
736+ << " After outerBlockInnerThread strategy :" << std::endl
718737 << *mappedScop->schedule ();
719738
720739 return mappedScop;
0 commit comments