@@ -437,14 +437,20 @@ bool isPromotableToRegistersBelow(
437437}
438438
439439/*
440- * Starting from the root, find bands where depth is reached. Using
440+ * Starting from the root, find bands where depth is reached. If zero depth is
441+ * requested, insert a zero-dimensional band node below the root (or the
442+ * context node if present) and return it. Otherwise, use
441443 * DFSPreorder to make sure order is specified and consistent for tests.
442444 */
443445std::vector<detail::ScheduleTree*> bandsContainingScheduleDepth (
444446 detail::ScheduleTree* root,
445447 size_t depth) {
446448 using namespace tc ::polyhedral::detail;
447449
450+ if (depth == 0 ) {
451+ return {insertTopLevelEmptyBand (root)};
452+ }
453+
448454 auto bands =
449455 ScheduleTree::collectDFSPreorder (root, detail::ScheduleTreeType::Band);
450456 std::function<bool (ScheduleTree * st)> containsDepth = [&](ScheduleTree* st) {
@@ -602,6 +608,15 @@ void promoteToSharedGreedy(
602608 scop.insertSyncsAroundCopies (bandNode);
603609 }
604610}
611+
612+ /*
613+ * Check if "tree" is a band node mapped to threads. In particular, check that
614+ * "tree" is a band and a thread-specific node appears as its only child.
615+ */
616+ inline bool isThreadMappedBand (const detail::ScheduleTree* tree) {
617+ return matchOne (band (threadSpecific (any ())), tree) ||
618+ matchOne (band (threadSpecific ()), tree);
619+ }
605620} // namespace
606621
607622void promoteGreedilyAtDepth (
@@ -698,16 +713,69 @@ void promoteToRegistersBelow(MappedScop& mscop, detail::ScheduleTree* scope) {
698713 partialSched);
699714 }
700715 }
716+
717+ // Return immediately if nothing was promoted.
718+ if (scope->numChildren () == 0 ||
719+ !matchOne (extension (sequence (any ())), scope->child ({0 }))) {
720+ return ;
721+ }
722+
723+ // If promoting above thread mapping, insert synchronizations.
724+ // It is possible that promoted array elements are accessed by different
725+ // threads outside the current scope (either in different iterations of the
726+ // scope loops, or in sibling subtrees). For now, always insert
727+ // synchronizations, similarly to copies to shared memory.
728+ //
729+ // TODO: The exact check for sync insertion requires the dependences between
730+ // the elements in the scope and those before/after the scope and a check if
731+ // the dependent instances belong to the same thread.
732+ auto ancestors = scope->ancestors (root);
733+ if (functional::Filter (isMappingTo<mapping::ThreadId>, ancestors).empty ()) {
734+ scop.insertSyncsAroundSeqChildren (scope->child ({0 , 0 }));
735+ }
701736}
702737
703- // Promote at the positions of the thread specific markers.
704- void promoteToRegistersBelowThreads (MappedScop& mscop, size_t nRegisters) {
705- auto & scop = mscop.scop ();
706- auto root = scop.scheduleRoot ();
707- auto markers = findThreadSpecificMarkers (root);
738+ /*
739+ * Promote to registers below "depth" schedule dimensions. Split bands if
740+ * necessary to create promotion scopes. Do not promote if it would require
741+ * splitting the band mapped to threads as we assume only one band can be
742+ * mapped.
743+ */
744+ void promoteToRegistersAtDepth (MappedScop& mscop, size_t depth) {
745+ using namespace detail ;
708746
709- for (auto marker : markers) {
710- promoteToRegistersBelow (mscop, marker);
747+ auto root = mscop.scop ().scheduleRoot ();
748+
749+ // 1. Collect all bands with a member located at the given depth in the
750+ // overall schedule. Make sure this is the last member of the band by
751+ // splitting off the subsequent members into a different band. Ignore bands
752+ // mapped to threads if splitting is required as it would break the invariant
753+ // of a single band being mapped to threads in a subtree.
754+ // TODO: allow splitting the thread-mapped bands; for example, tile them
755+ // explicitly with block size, use the point loops for thread mapping
756+ // but ignore them in depth computation.
757+ auto bands = bandsContainingScheduleDepth (root, depth);
758+ bands = functional::Filter (
759+ [root, depth](ScheduleTree* tree) {
760+ auto band = tree->elemAs <ScheduleTreeElemBand>();
761+ return !isThreadMappedBand (tree) ||
762+ tree->scheduleDepth (root) + band->nMember () == depth;
763+ },
764+ bands);
765+ bands = bandsSplitAfterDepth (bands, root, depth);
766+
767+ // 2. We don't want copies inserted between thread-mapped bands and the
768+ // thread-specific marker, but rather below that marker. If any of the bands
769+ // are mapped to threads, take their first children as promotion scope
770+ // instead of the band itself.
771+ std::function<ScheduleTree*(ScheduleTree*)> findScope =
772+ [](ScheduleTree* tree) {
773+ return isThreadMappedBand (tree) ? tree->child ({0 }) : tree;
774+ };
775+ auto scopes = functional::Map (findScope, bands);
776+
777+ for (auto scope : scopes) {
778+ promoteToRegistersBelow (mscop, scope);
711779 }
712780}
713781
0 commit comments