@@ -124,14 +124,6 @@ void MappedScop::mapRemaining(detail::ScheduleTree* tree, size_t nMapped) {
124124 auto filter = makeFixRemainingZeroFilter (domain, ids);
125125 auto mapping = detail::ScheduleTree::makeMappingFilter (filter, ids);
126126 insertNodeAbove (root, tree, std::move (mapping));
127-
128- for (size_t i = nMapped; i < nToMap; ++i) {
129- if (MappingTypeId::makeId (i) == mapping::ThreadId::x ()) {
130- threadIdxXScheduleDepthState.emplace_back (std::make_pair (
131- activeDomainPoints (schedule (), tree),
132- tree->scheduleDepth (schedule ())));
133- }
134- }
135127}
136128
137129// Uses as many blockSizes elements as outer coincident dimensions in the
@@ -161,6 +153,7 @@ void MappedScop::mapToBlocksAndScaleBand(
161153 * Given a node in the schedule tree of a mapped scop,
162154 * insert a mapping filter underneath (if needed) that fixes
163155 * the remaining thread identifiers starting at "begin" to zero.
156+ * Add a marker underneath that marks the subtree that is thread specific.
164157 */
165158void fixThreadsBelow (
166159 MappedScop& mscop,
@@ -173,6 +166,9 @@ void fixThreadsBelow(
173166
174167 auto band = detail::ScheduleTree::makeEmptyBand (mscop.scop ().scheduleRoot ());
175168 auto bandTree = insertNodeBelow (tree, std::move (band));
169+ auto ctx = tree->ctx_ ;
170+ insertNodeBelow (
171+ bandTree, detail::ScheduleTree::makeThreadSpecificMarker (ctx));
176172 mscop.mapRemaining <mapping::ThreadId>(bandTree, begin);
177173}
178174
@@ -338,8 +334,29 @@ size_t MappedScop::mapToThreads(detail::ScheduleTree* band) {
338334 return 0 ;
339335 }
340336
341- auto nMappedThreads =
342- std::min (numThreads.view .size (), static_cast <size_t >(nCanMap));
337+ auto nMappedThreads = nCanMap;
338+ if (nMappedThreads > numThreads.view .size ()) {
339+ // Split band such that mapping filters get inserted
340+ // right above the first member mapped to a thread identifier.
341+ nMappedThreads = numThreads.view .size ();
342+ bandSplit (scop_->scheduleRoot (), band, nCanMap - nMappedThreads);
343+ auto child = band->child ({0 });
344+ if (isReduction) {
345+ // Update reductionBandUpdates_ such that splitOutReductionAndInsertSyncs
346+ // can find the information it needs.
347+ reductionBandUpdates_.emplace (child, reductionBandUpdates_.at (band));
348+ reductionBandUpdates_.erase (band);
349+ }
350+ band = child;
351+ bandNode = band->elemAs <ScheduleTreeElemBand>();
352+ }
353+
354+ if (nMappedThreads < bandNode->nMember ()) {
355+ bandSplit (scop_->scheduleRoot (), band, nMappedThreads);
356+ }
357+
358+ auto ctx = band->ctx_ ;
359+ insertNodeBelow (band, detail::ScheduleTree::makeThreadSpecificMarker (ctx));
343360
344361 CHECK_GT (nMappedThreads, 0 ) << " not mapping to threads" ;
345362 CHECK_LE (nMappedThreads, 3 ) << " mapping to too many threads" ;
@@ -348,20 +365,16 @@ size_t MappedScop::mapToThreads(detail::ScheduleTree* band) {
348365 // from thread x.
349366 for (size_t i = 0 ; i < nMappedThreads; ++i) {
350367 auto id = mapping::ThreadId::makeId (i);
351- auto dim = nCanMap - 1 - i;
352- if (id == mapping::ThreadId::x ()) {
353- threadIdxXScheduleDepthState.emplace_back (std::make_pair (
354- activeDomainPoints (schedule (), band),
355- band->scheduleDepth (schedule ()) + dim));
356- }
368+ auto dim = nMappedThreads - 1 - i;
357369 band = map (band, dim, id);
358370 }
371+ mapRemaining<mapping::ThreadId>(band, nMappedThreads);
359372
360373 if (isReduction) {
361- splitOutReductionAndInsertSyncs (band, nCanMap - 1 );
374+ splitOutReductionAndInsertSyncs (band, nMappedThreads - 1 );
362375 }
363376
364- return nMappedThreads ;
377+ return numThreads. view . size () ;
365378}
366379
367380namespace {
@@ -450,9 +463,8 @@ size_t MappedScop::mapInnermostBandsToThreads(detail::ScheduleTree* st) {
450463 // because we cannot map parent bands anyway.
451464 auto nMapped = mapToThreads (st);
452465 if (nMapped > 0 ) {
453- mapRemaining<mapping::ThreadId>(st, nMapped);
454466 markUnroll (scop_->scheduleRoot (), st, unroll);
455- return numThreads. view . size () ;
467+ return nMapped ;
456468 }
457469 } else if (anyNonCoincidentMember (band)) {
458470 // If children were mapped to threads, and this band has a non-coincident
@@ -633,7 +645,7 @@ std::unique_ptr<MappedScop> MappedScop::makeWithOuterBlockInnerThreadStrategy(
633645 auto child = outerBand->child ({0 });
634646 size_t numMappedInnerThreads =
635647 mappedScop->mapInnermostBandsToThreads (child);
636- mappedScop-> mapRemaining <mapping::ThreadId>(child , numMappedInnerThreads);
648+ fixThreadsBelow (*mappedScop, outerBand , numMappedInnerThreads);
637649 LOG_IF (INFO, FLAGS_debug_tc_mapper)
638650 << " After mapping to threads:" << std::endl
639651 << *mappedScop->schedule ();
@@ -677,7 +689,6 @@ std::unique_ptr<MappedScop> MappedScop::makeWithOuterBlockInnerThreadStrategy(
677689
678690 promoteGreedilyAtDepth (
679691 *mappedScop,
680- mappedScop->threadIdxXScheduleDepthState ,
681692 std::min (band->nOuterCoincident (), mappedScop->numBlocks .view .size ()),
682693 sharedMemorySize,
683694 cudaOptions.proto ().unroll_copy_shared () &&
@@ -694,8 +705,7 @@ std::unique_ptr<MappedScop> MappedScop::makeWithOuterBlockInnerThreadStrategy(
694705
695706 // 8. Promote to registers below the loops mapped to threads.
696707 if (cudaOptions.proto ().use_private_memory ()) {
697- promoteToRegistersBelowThreads (
698- mappedScop->scop (), mappedScop->threadIdxXScheduleDepthState , -1ull );
708+ promoteToRegistersBelowThreads (mappedScop->scop (), -1ull );
699709 }
700710
701711 // 9. Insert mapping context