@@ -99,23 +99,22 @@ isl::union_map partialSchedule(
9999 return partialScheduleImpl (root, node, true );
100100}
101101
102- // Get a set of domain elements that are active at the given node.
102+ namespace {
103+ // Get a set of domain elements that are active below
104+ // the given branch of nodes.
103105//
104106// Domain elements are introduced by the root domain node. Filter nodes
105107// disable the points that do not intersect with the filter. Extension nodes
106108// are considered to introduce additional domain points.
107- isl::union_set activeDomainPoints (
109+ isl::union_set activeDomainPointsHelper (
108110 const ScheduleTree* root,
109- const ScheduleTree* node ) {
111+ const vector< const ScheduleTree*>& nodes ) {
110112 auto domainElem = root->elemAs <ScheduleTreeElemDomain>();
111113 CHECK (domainElem) << " root must be a Domain node" << *root;
112114
113115 auto domain = domainElem->domain_ ;
114- if (root == node) {
115- return domain;
116- }
117116
118- for (auto anc : node-> ancestors (root) ) {
117+ for (auto anc : nodes ) {
119118 if (auto filterElem = anc->elemAsBase <ScheduleTreeElemFilter>()) {
120119 domain = domain.intersect (filterElem->filter_ );
121120 } else if (auto extensionElem = anc->elemAs <ScheduleTreeElemExtension>()) {
@@ -134,6 +133,21 @@ isl::union_set activeDomainPoints(
134133 }
135134 return domain;
136135}
136+ } // namespace
137+
138+ isl::union_set activeDomainPoints (
139+ const ScheduleTree* root,
140+ const ScheduleTree* node) {
141+ return activeDomainPointsHelper (root, node->ancestors (root));
142+ }
143+
144+ isl::union_set activeDomainPointsBelow (
145+ const ScheduleTree* root,
146+ const ScheduleTree* node) {
147+ auto ancestors = node->ancestors (root);
148+ ancestors.emplace_back (node);
149+ return activeDomainPointsHelper (root, ancestors);
150+ }
137151
138152vector<ScheduleTree*> collectScheduleTreesPath (
139153 std::function<ScheduleTree*(ScheduleTree*)> next,
@@ -473,8 +487,16 @@ void updateTopLevelContext(detail::ScheduleTree* root, isl::set context) {
473487 contextElem->context_ = contextElem->context_ & context;
474488}
475489
476- ScheduleTree* insertSequenceAbove (ScheduleTree* root, ScheduleTree* tree) {
477- auto parent = tree->ancestor (root, 1 );
490+ namespace {
491+
492+ // In a tree starting at "root", insert a sequence node with
493+ // as only child the node identified by "tree"
494+ // within the subtree at "relativeRoot".
495+ ScheduleTree* insertSequenceAbove (
496+ const ScheduleTree* root,
497+ ScheduleTree* relativeRoot,
498+ ScheduleTree* tree) {
499+ auto parent = tree->ancestor (relativeRoot, 1 );
478500 auto childPos = tree->positionInParent (parent);
479501 auto filter = activeDomainPoints (root, tree).universe ();
480502 parent->insertChild (
@@ -484,11 +506,27 @@ ScheduleTree* insertSequenceAbove(ScheduleTree* root, ScheduleTree* tree) {
484506 return parent->child ({childPos});
485507}
486508
509+ } // namespace
510+
511+ ScheduleTree* insertSequenceAbove (ScheduleTree* root, ScheduleTree* tree) {
512+ return insertSequenceAbove (root, root, tree);
513+ }
514+
515+ void insertSequenceBelow (
516+ const detail::ScheduleTree* root,
517+ detail::ScheduleTree* tree) {
518+ auto numChildren = tree->numChildren ();
519+ CHECK_LE (numChildren, 1u );
520+ auto filter = activeDomainPointsBelow (root, tree).universe ();
521+ auto node = ScheduleTree::makeFilter (filter, tree->detachChildren ());
522+ tree->appendChild (ScheduleTree::makeSequence (std::move (node)));
523+ }
524+
487525ScheduleTree* insertExtensionAbove (
488- ScheduleTree* root ,
526+ ScheduleTree* relativeRoot ,
489527 ScheduleTree* tree,
490528 isl::union_map extension) {
491- auto parent = tree->ancestor (root , 1 );
529+ auto parent = tree->ancestor (relativeRoot , 1 );
492530 auto childPos = tree->positionInParent (parent);
493531 auto child = parent->detachChild (childPos);
494532 parent->insertChild (
@@ -500,85 +538,153 @@ namespace {
500538/*
501539 * Insert an empty extension node above "st" in a tree with the given root and
502540 * return a pointer to the inserted extension node.
541+ * The modification is performed within the subtree at "relativeRoot".
503542 */
504543detail::ScheduleTree* insertEmptyExtensionAbove (
505- ScheduleTree* root,
544+ const ScheduleTree* root,
545+ ScheduleTree* relativeRoot,
506546 ScheduleTree* st) {
507547 auto domain = root->elemAs <ScheduleTreeElemDomain>();
508548 CHECK (domain);
509549 auto space = domain->domain_ .get_space ();
510550 auto extension = isl::union_map::empty (space);
511- return insertExtensionAbove (root , st, extension);
551+ return insertExtensionAbove (relativeRoot , st, extension);
512552}
513- } // namespace
514553
515- void insertExtensionLabelAt (
516- ScheduleTree* root,
554+ /*
555+ * Construct an extension map for a zero-dimensional statement
556+ * with the given identifier.
557+ */
558+ isl::map labelExtension (ScheduleTree* root, ScheduleTree* tree, isl::id id) {
559+ auto prefix = prefixScheduleMupa (root, tree);
560+ auto scheduleSpace = prefix.get_space ();
561+ auto space = scheduleSpace.params ().set_from_params ().set_tuple_id (
562+ isl::dim_type::set, id);
563+ auto extensionSpace = scheduleSpace.map_from_domain_and_range (space);
564+ return isl::map::universe (extensionSpace);
565+ }
566+
567+ /*
568+ * Construct a filter node for a zero-dimensional extension statement
569+ * with the given extension map.
570+ */
571+ ScheduleTreeUPtr labelFilterFromExtension (isl::map extension) {
572+ return detail::ScheduleTree::makeFilter (extension.range ());
573+ }
574+
575+ /*
576+ * Given a sequence node in the schedule tree, insert
577+ * an extension with the given extension map and extension filter node
578+ * before the child at position "pos".
579+ * If "pos" is equal to the number of children, then
580+ * the statement is added after the last child.
581+ * The modification is performed within the subtree at "relativeRoot".
582+ */
583+ void insertExtensionAt (
584+ const ScheduleTree* root,
585+ ScheduleTree* relativeRoot,
517586 ScheduleTree* seqNode,
518587 size_t pos,
519- isl::id id) {
520- auto extensionTree = seqNode->ancestor (root, 1 );
588+ isl::union_map extension,
589+ ScheduleTreeUPtr&& filterNode) {
590+ auto extensionTree = seqNode->ancestor (relativeRoot, 1 );
521591 auto extensionNode =
522592 extensionTree->elemAs <detail::ScheduleTreeElemExtension>();
523593 if (!extensionNode) {
524- extensionTree = insertEmptyExtensionAbove (root, seqNode);
594+ extensionTree = insertEmptyExtensionAbove (root, relativeRoot, seqNode);
525595 extensionNode = extensionTree->elemAs <detail::ScheduleTreeElemExtension>();
526596 }
527597 CHECK (extensionNode);
528598 CHECK (seqNode->elemAs <detail::ScheduleTreeElemSequence>());
529- auto prefix = prefixScheduleMupa (root, extensionTree);
530- auto scheduleSpace = prefix.get_space ();
531- auto space = scheduleSpace.params ().set_from_params ().set_tuple_id (
532- isl::dim_type::set, id);
533- auto extensionSpace = scheduleSpace.map_from_domain_and_range (space);
534- auto extension = isl::map::universe (extensionSpace);
535599 extensionNode->extension_ = extensionNode->extension_ .unite (extension);
536- auto filterNode = detail::ScheduleTree::makeFilter (extension.range ());
537600 seqNode->insertChild (pos, std::move (filterNode));
538601}
602+ } // namespace
539603
540- void insertExtensionLabelBefore (
541- ScheduleTree* root,
604+ void insertExtensionBefore (
605+ const ScheduleTree* root,
606+ ScheduleTree* relativeRoot,
542607 ScheduleTree* tree,
543- isl::id id) {
608+ isl::union_map extension,
609+ ScheduleTreeUPtr&& filterNode) {
544610 size_t pos;
545- auto parent = tree->ancestor (root , 1 );
611+ auto parent = tree->ancestor (relativeRoot , 1 );
546612 ScheduleTree* seqTree;
613+ if (tree->elemAs <detail::ScheduleTreeElemExtension>()) {
614+ tree = tree->child ({0 });
615+ parent = tree;
616+ }
547617 if (tree->elemAs <detail::ScheduleTreeElemSequence>()) {
548618 seqTree = tree;
549619 pos = 0 ;
550620 } else if (
551621 parent->elemAs <detail::ScheduleTreeElemFilter>() &&
552622 parent->ancestor (root, 1 )->elemAs <detail::ScheduleTreeElemSequence>()) {
553- seqTree = parent->ancestor (root , 1 );
623+ seqTree = parent->ancestor (relativeRoot , 1 );
554624 pos = parent->positionInParent (seqTree);
555625 } else {
556- seqTree = insertSequenceAbove (root, tree);
626+ seqTree = insertSequenceAbove (root, relativeRoot, tree);
557627 pos = 0 ;
558628 }
559- insertExtensionLabelAt (root, seqTree, pos, id);
629+ insertExtensionAt (
630+ root, relativeRoot, seqTree, pos, extension, std::move (filterNode));
560631}
561632
562- void insertExtensionLabelAfter (
563- ScheduleTree* root,
633+ void insertExtensionAfter (
634+ const ScheduleTree* root,
635+ ScheduleTree* relativeRoot,
564636 ScheduleTree* tree,
565- isl::id id) {
637+ isl::union_map extension,
638+ ScheduleTreeUPtr&& filterNode) {
566639 size_t pos;
567- auto parent = tree->ancestor (root , 1 );
640+ auto parent = tree->ancestor (relativeRoot , 1 );
568641 ScheduleTree* seqTree;
642+ if (tree->elemAs <detail::ScheduleTreeElemExtension>()) {
643+ tree = tree->child ({0 });
644+ parent = tree;
645+ }
569646 if (tree->elemAs <detail::ScheduleTreeElemSequence>()) {
570647 seqTree = tree;
571648 pos = tree->numChildren ();
572649 } else if (
573650 parent->elemAs <detail::ScheduleTreeElemFilter>() &&
574651 parent->ancestor (root, 1 )->elemAs <detail::ScheduleTreeElemSequence>()) {
575- seqTree = parent->ancestor (root , 1 );
652+ seqTree = parent->ancestor (relativeRoot , 1 );
576653 pos = parent->positionInParent (seqTree) + 1 ;
577654 } else {
578- seqTree = insertSequenceAbove (root, tree);
655+ seqTree = insertSequenceAbove (root, relativeRoot, tree);
579656 pos = 1 ;
580657 }
581- insertExtensionLabelAt (root, seqTree, pos, id);
658+ insertExtensionAt (
659+ root, relativeRoot, seqTree, pos, extension, std::move (filterNode));
660+ }
661+
662+ void insertExtensionLabelAt (
663+ ScheduleTree* root,
664+ ScheduleTree* seqNode,
665+ size_t pos,
666+ isl::id id) {
667+ auto extension = labelExtension (root, seqNode, id);
668+ auto filterNode = labelFilterFromExtension (extension);
669+ insertExtensionAt (root, root, seqNode, pos, extension, std::move (filterNode));
670+ }
671+
672+ void insertExtensionLabelBefore (
673+ ScheduleTree* root,
674+ ScheduleTree* tree,
675+ isl::id id) {
676+ auto extension = labelExtension (root, tree, id);
677+ auto filterNode = labelFilterFromExtension (extension);
678+ insertExtensionBefore (root, root, tree, extension, std::move (filterNode));
679+ }
680+
681+ void insertExtensionLabelAfter (
682+ ScheduleTree* root,
683+ ScheduleTree* tree,
684+ isl::id id) {
685+ auto extension = labelExtension (root, tree, id);
686+ auto filterNode = labelFilterFromExtension (extension);
687+ insertExtensionAfter (root, root, tree, extension, std::move (filterNode));
582688}
583689
584690namespace {
0 commit comments