@@ -461,45 +461,13 @@ bool hasOuterSequentialMember(
461461 return false ;
462462}
463463
464+ // Name of the space of blocks inside the grid
465+ constexpr auto kGrid = " grid" ;
464466// Name of the space of threads inside a block
465467constexpr auto kBlock = " block" ;
466468// Name of the space of warps
467469constexpr auto kWarp = " warp" ;
468470
469- /*
470- * Extract a mapping from the domain elements active at "tree"
471- * to the thread identifiers, where all branches in "tree"
472- * are assumed to have been mapped to thread identifiers.
473- * "nThread" is the number of thread identifiers.
474- * The result lives in a space of the form block[x, ...].
475- */
476- isl::multi_union_pw_aff extractDomainToThread (
477- const detail::ScheduleTree* tree,
478- size_t nThread) {
479- using namespace polyhedral ::detail;
480-
481- auto space = isl::space (tree->ctx_ , 0 );
482- auto empty = isl::union_set::empty (space);
483- auto id = isl::id (tree->ctx_ , kBlock );
484- space = space.named_set_from_params_id (id, nThread);
485- auto zero = isl::multi_val::zero (space);
486- auto domainToThread = isl::multi_union_pw_aff (empty, zero);
487-
488- for (auto mapping : tree->collect (tree, ScheduleTreeType::MappingFilter)) {
489- auto mappingNode = mapping->elemAs <ScheduleTreeElemMappingFilter>();
490- auto list = isl::union_pw_aff_list (tree->ctx_ , nThread);
491- for (size_t i = 0 ; i < nThread; ++i) {
492- auto threadId = mapping::ThreadId::makeId (i);
493- auto threadMap = mappingNode->mapping .at (threadId);
494- list = list.add (threadMap);
495- }
496- auto nodeToThread = isl::multi_union_pw_aff (space, list);
497- domainToThread = domainToThread.union_add (nodeToThread);
498- }
499-
500- return domainToThread;
501- }
502-
503471/*
504472 * Construct a mapping
505473 *
@@ -534,6 +502,26 @@ isl::multi_aff constructThreadToWarp(
534502}
535503} // namespace
536504
505+ isl::multi_union_pw_aff MappedScop::threadMappingSchedule (
506+ const detail::ScheduleTree* tree) const {
507+ std::vector<mapping::MappingId> ids;
508+ for (size_t i = 0 ; i < numThreads.view .size (); ++i) {
509+ ids.emplace_back (mapping::ThreadId::makeId (i));
510+ }
511+ auto tupleId = isl::id (tree->ctx_ , kBlock );
512+ return extractDomainToIds (scop_->scheduleRoot (), tree, ids, tupleId);
513+ }
514+
515+ isl::multi_union_pw_aff MappedScop::blockMappingSchedule (
516+ const detail::ScheduleTree* tree) const {
517+ std::vector<mapping::MappingId> ids;
518+ for (size_t i = 0 ; i < numBlocks.view .size (); ++i) {
519+ ids.emplace_back (mapping::BlockId::makeId (i));
520+ }
521+ auto tupleId = isl::id (tree->ctx_ , kGrid );
522+ return extractDomainToIds (scop_->scheduleRoot (), tree, ids, tupleId);
523+ }
524+
537525Scop::SyncLevel MappedScop::findBestSync (
538526 detail::ScheduleTree* st1,
539527 detail::ScheduleTree* st2,
@@ -724,7 +712,7 @@ void MappedScop::insertBestSyncInSeq(detail::ScheduleTree* seq) {
724712
725713 auto outer = hasOuterSequentialMember (scop_->scheduleRoot (), seq);
726714
727- auto domainToThread = extractDomainToThread (seq, numThreads. view . size () );
715+ auto domainToThread = threadMappingSchedule (seq);
728716 auto threadToWarp = constructThreadToWarp (seq->ctx_ , 32 , numThreads);
729717 auto domainToWarp = domainToThread.apply (threadToWarp);
730718
@@ -1080,7 +1068,7 @@ std::unique_ptr<MappedScop> MappedScop::makeWithOuterBlockInnerThreadStrategy(
10801068
10811069 // 9. Promote to registers below the loops mapped to threads.
10821070 if (cudaOptions.proto ().use_private_memory ()) {
1083- promoteToRegistersBelowThreads (mappedScop-> scop () , -1ull );
1071+ promoteToRegistersBelowThreads (* mappedScop, -1ull );
10841072 }
10851073
10861074 LOG_IF (INFO, FLAGS_debug_tc_mapper)
0 commit comments