@@ -431,167 +431,107 @@ bool hasOuterSequentialMember(
431431 return false ;
432432}
433433
434- // Intersect the union set with all the mapping
435- // filters params in the given schedule tree
436- isl::union_set intersectMappingFilterParams (
437- detail::ScheduleTree* st,
438- isl::union_set us) {
439- if (auto filter = st->elemAsBase <detail::ScheduleTreeElemFilter>()) {
440- us = us.intersect (filter->filter_ );
441- }
434+ // Name of the space of threads inside a block; used as the tuple id of the block[...] set spaces built in extractDomainToThread and constructThreadToWarp
435+ constexpr auto kBlock = " block" ;
436+ // Name of the space of warps; used as the tuple id of the one-dimensional warp[...] set space built in constructThreadToWarp
437+ constexpr auto kWarp = " warp" ;
442438
443- auto children = st->children ();
444- auto nChildren = children.size ();
445- if (nChildren == 1 ) {
446- us = intersectMappingFilterParams (children[0 ], us);
447- } else if (nChildren > 1 ) {
448- auto usParent = us;
449- us = intersectMappingFilterParams (children[0 ], us);
450- for (size_t i = 1 ; i < nChildren; ++i) {
451- us = us.unite (intersectMappingFilterParams (children[i], usParent));
452- }
453- }
454-
455- return us;
456- }
439+ /*
440+ * Extract a mapping from the domain elements active at "tree"
441+ * to the thread identifiers, where all branches in "tree"
442+ * are assumed to have been mapped to thread identifiers.
443+ * "nThread" is the number of thread identifiers.
444+ * The result lives in a space of the form block[x, ...].
+ *
+ * The result is accumulated starting from a zero-valued expression
+ * over an empty domain. For every mapping filter node collected
+ * from "tree", the expressions stored for the thread identifiers
+ * 0 .. nThread - 1 are gathered into one tuple in the block space
+ * and merged into the result with union_add.
+ * NOTE(review): presumably the mapping filters cover disjoint parts
+ * of the domain, so that union_add simply combines them - confirm.
445+ */
446+ isl::multi_union_pw_aff extractDomainToThread (
447+ const detail::ScheduleTree* tree,
448+ size_t nThread) {
449+ using namespace polyhedral ::detail;
457450
458- // Change the name of the isl ids tied to threads and blocks
459- // by adding a suffix
460- isl::union_set modifyMappingNames (
461- isl::union_set set,
462- const std::string suffix) {
463- USING_MAPPING_SHORT_NAMES (BX, BY, BZ, TX, TY, TZ);
464- std::unordered_set<isl::id, isl::IslIdIslHash> identifiers{
465- BX, BY, BZ, TX, TY, TZ};
466-
467- auto space = set.get_space ();
468- for (auto id : identifiers) {
469- auto name = id.get_name ();
470- auto dim = space.find_dim_by_name (isl::dim_type::param, id.get_name ());
471- CHECK_LE (0 , dim);
472- space = space.set_dim_name (isl::dim_type::param, dim, name + suffix);
473- }
474- auto newSet = isl::union_set::empty (space);
475- set.foreach_set ([&newSet, &identifiers, &suffix](isl::set setInFun) {
476- for (auto id : identifiers) {
477- auto name = id.get_name ();
478- auto dim =
479- setInFun.get_space ().find_dim_by_name (isl::dim_type::param, name);
480- CHECK_LE (0 , dim);
481- setInFun =
482- setInFun.set_dim_name (isl::dim_type::param, dim, name + suffix);
451+ auto space = isl::space (tree->ctx_ , 0 );
452+ auto empty = isl::union_set::empty (space);
453+ auto id = isl::id (tree->ctx_ , kBlock );
454+ space = space.named_set_from_params_id (id, nThread);
455+ auto zero = isl::multi_val::zero (space);
456+ auto domainToThread = isl::multi_union_pw_aff (empty, zero); // initial value: zero over an empty domain
457+
458+ for (auto mapping : tree->collect (tree, ScheduleTreeType::MappingFilter)) {
459+ auto mappingNode = mapping->elemAs <ScheduleTreeElemMappingFilter>();
460+ auto list = isl::union_pw_aff_list (tree->ctx_ , nThread);
461+ for (size_t i = 0 ; i < nThread; ++i) {
462+ auto threadId = mapping::ThreadId::makeId (i);
463+ auto threadMap = mappingNode->mapping .at (threadId); // NOTE(review): presumably fails if this node has no entry for threadId - confirm
464+ list = list.add (threadMap);
483465 }
484- newSet = newSet.unite (setInFun);
485- });
486- return newSet;
487- }
488-
489- // Get the formula computing the linearized index of a thread in a block.
490- isl::aff getLinearizedThreadIdxFormula (
491- isl::space space,
492- const Block& block,
493- const std::string& suffix = " " ) {
494- USING_MAPPING_SHORT_NAMES (BX, BY, BZ, TX, TY, TZ);
495- std::vector<std::pair<isl::id, unsigned >> mappingIds{
496- {TX, TX.mappingSize (block)},
497- {TY, TY.mappingSize (block)},
498- {TZ, TZ.mappingSize (block)}};
499-
500- isl::aff formula = isl::aff (isl::local_space (space));
501-
502- for (int i = (int )mappingIds.size () - 1 ; i >= 0 ; --i) {
503- auto name = mappingIds[i].first .to_str ();
504- auto dim = space.find_dim_by_name (isl::dim_type::param, name + suffix);
505- CHECK_LE (0 , dim);
506- auto id = space.get_dim_id (isl::dim_type::param, dim);
507- isl::aff aff (isl::aff::param_on_domain_space (space, id));
508- formula = formula * mappingIds[i].second + aff;
466+ auto nodeToThread = isl::multi_union_pw_aff (space, list); // tuple of the nThread expressions in the block space
467+ domainToThread = domainToThread.union_add (nodeToThread);
509468 }
510469
511- return formula ;
470+ return domainToThread ;
512471}
513472
514- // Return the constraints ensuring that the points with parameters
515- // [t0,t1,t2] and [t0',t1',t2'] are in the same warp.
516- // (where t0 is "t0" + suffix1 and t0' is "t0" + suffix2)
517- // if suffix1 is "_1" and suffix2 is "_2", the constraint is in the form
518- // ((t0_1 + a * t1_1 + b * t2_1) / warpSize).floor()
519- // == ((t0_2 + a' * t1_2 + b' * t1_2) / warpSize).floor()
520- // with t0_1 + a * t1_1 + b * t2_1 the linearized formula of the thread index.
521- // This function returns a set because it might change in the future,
522- // and take into account the blocks.
523- isl::set getSameWarpConstraints (
524- isl::space space,
525- const std::string& suffix1,
526- const std::string& suffix2 ,
473+ /*
474+ * Construct a mapping
475+ *
476+ * block[x] -> warp[floor((x)/warpSize)]
477+ * block[x, y] -> warp[floor((x + s_x * (y))/ warpSize)]
478+ * block[x, y, z] -> warp[floor((x + s_x * (y + s_y * (z)))/ warpSize)]
479+ *
480+ * uniquely mapping thread identifiers that belong to the same warp
481+ * (of size "warpSize") to a warp identifier,
482+ * based on the thread sizes s_x, s_y up to s_z in "block".
+ *
+ * The linearized thread index is accumulated from the last thread
+ * dimension down to the first (aff = aff * block.view[i] + i-th
+ * coordinate). The size of the last dimension therefore only ever
+ * multiplies zero and does not influence the result.
+ * The number of thread dimensions is taken from block.view.size().
+ * The returned multi_aff has a single output dimension and lives
+ * in the space block[...] -> warp[w].
483+ */
484+ isl::multi_aff constructThreadToWarp (
485+ isl::ctx ctx ,
527486 const unsigned warpSize,
528487 const Block& block) {
529- USING_MAPPING_SHORT_NAMES (BX, BY, BZ, TX, TY, TZ);
530- std::vector<std::pair<isl::id, unsigned >> mappingIds{
531- {TX, TX.mappingSize (block)},
532- {TY, TY.mappingSize (block)},
533- {TZ, TZ.mappingSize (block)}};
534-
535- auto formula1 = getLinearizedThreadIdxFormula (space, block, suffix1);
536- auto formula2 = getLinearizedThreadIdxFormula (space, block, suffix2);
488+ auto space = isl::space (ctx, 0 );
489+ auto id = isl::id (ctx, kBlock );
490+ auto blockSpace = space.named_set_from_params_id (id, block.view .size ());
491+ auto warpSpace = space.named_set_from_params_id (isl::id (ctx, kWarp ), 1 );
492+ auto aff = isl::aff::zero_on_domain (blockSpace);
493+
494+ auto nThread = block.view .size ();
495+ auto identity = isl::multi_aff::identity (blockSpace.map_from_set ()); // identity.get_aff(i) is the i-th block coordinate
496+ for (int i = nThread - 1 ; i >= 0 ; --i) { // signed index so that the descending loop ends after i == 0
497+ aff = aff.scale (isl::val (ctx, block.view [i]));
498+ aff = aff.add (identity.get_aff (i));
499+ }
537500
538- return (
539- isl::aff_set ((formula1 / warpSize). floor ()) ==
540- (formula2 / warpSize). floor ( ));
501+ aff = aff. scale_down ( isl::val (ctx, warpSize)). floor (); // linearized index divided by warpSize, rounded down
502+ auto mapSpace = blockSpace. product (warpSpace). unwrap ();
503+ return isl::multi_aff (mapSpace, isl::aff_list (aff ));
541504}
542505} // namespace
543506
544507Scop::SyncLevel MappedScop::findBestSync (
545508 detail::ScheduleTree* st1,
546- detail::ScheduleTree* st2) {
509+ detail::ScheduleTree* st2,
510+ isl::multi_union_pw_aff domainToThread,
511+ isl::multi_union_pw_aff domainToWarp) {
547512 // Active points in the two schedule trees
548513 auto stRoot = scop_->scheduleRoot ();
549514 auto activePoints1 = activeDomainPointsBelow (stRoot, st1);
550515 auto activePoints2 = activeDomainPointsBelow (stRoot, st2);
551516
552517 // The dependences between the two schedule trees
553- auto dependences =
554- isl::union_map::from_domain_and_range (activePoints1, activePoints2 );
555- dependences = dependences.intersect (scop_-> dependences );
518+ auto dependences = scop_-> dependences ;
519+ dependences = dependences. intersect_domain (activePoints1);
520+ dependences = dependences.intersect_range (activePoints2 );
556521 if (dependences.is_empty ()) {
557522 return Scop::SyncLevel::None;
558523 }
559524
560- // The domain and the context of the root schedule tree
561- auto domainAndContext = scop_->domain ();
562525 CHECK_LE (1u , scop_->scheduleRoot ()->children ().size ());
563526 auto contextSt = scop_->scheduleRoot ()->children ()[0 ];
564527 auto contextElem = contextSt->elemAs <detail::ScheduleTreeElemContext>();
565528 CHECK (nullptr != contextElem);
566- domainAndContext = domainAndContext .intersect_params (contextElem->context_ );
529+ dependences = dependences .intersect_params (contextElem->context_ );
567530
568- // The domain of both schedule trees filtered by mapping filters,
569- // and then modified to have different threads and blocks names.
570- auto domain1 = intersectMappingFilterParams (st1, domainAndContext);
571- auto domain2 = intersectMappingFilterParams (st2, domainAndContext);
572- auto suffix1 = " _1" ;
573- auto suffix2 = " _2" ;
574- domain1 = modifyMappingNames (domain1, suffix1);
575- domain2 = modifyMappingNames (domain2, suffix2);
576-
577- // The dependences between the two schedule trees
578- // with mapping from threads and blocks
579- auto mappedDependences =
580- isl::union_map::from_domain_and_range (domain1, domain2);
581- mappedDependences = mappedDependences.intersect (dependences);
582-
583- auto space = mappedDependences.get_space ();
584- auto sameThreadConstraint =
585- getSameWarpConstraints (space, suffix1, suffix2, 1 , numThreads);
586- auto sameWarpConstraints =
587- getSameWarpConstraints (space, suffix1, suffix2, 32 , numThreads);
588-
589- if (mappedDependences ==
590- mappedDependences.intersect_params (sameThreadConstraint)) {
531+ if (dependences.is_subset (dependences.eq_at (domainToThread))) {
591532 return Scop::SyncLevel::None;
592- } else if (
593- mappedDependences ==
594- mappedDependences.intersect_params (sameWarpConstraints)) {
533+ }
534+ if (dependences.is_subset (dependences.eq_at (domainToWarp))) {
595535 return Scop::SyncLevel::Warp;
596536 }
597537 return Scop::SyncLevel::Block;
@@ -754,6 +694,10 @@ void MappedScop::insertBestSyncInSeq(detail::ScheduleTree* seq) {
754694
755695 auto outer = hasOuterSequentialMember (scop_->scheduleRoot (), seq);
756696
697+ auto domainToThread = extractDomainToThread (seq, numThreads.view .size ());
698+ auto threadToWarp = constructThreadToWarp (seq->ctx_ , 32 , numThreads);
699+ auto domainToWarp = domainToThread.apply (threadToWarp);
700+
757701 std::vector<std::vector<int >> bestSync (
758702 nChildren, std::vector<int >(nChildren + 1 ));
759703 // Get the synchronization needed between children[i] and children[i+k]
@@ -765,7 +709,8 @@ void MappedScop::insertBestSyncInSeq(detail::ScheduleTree* seq) {
765709 for (size_t i = 0 ; i < nChildren; ++i) {
766710 for (size_t k = 0 ; k < nChildren; ++k) {
767711 auto ik = (i + k) % nChildren;
768- bestSync[i][k] = (int )findBestSync (children[i], children[ik]);
712+ bestSync[i][k] = (int )findBestSync (
713+ children[i], children[ik], domainToThread, domainToWarp);
769714 }
770715 }
771716
0 commit comments