@@ -507,8 +507,16 @@ void updateTopLevelContext(detail::ScheduleTree* root, isl::set context) {
507507 contextElem->context_ = contextElem->context_ & context;
508508}
509509
510- ScheduleTree* insertSequenceAbove (ScheduleTree* root, ScheduleTree* tree) {
511- auto parent = tree->ancestor (root, 1 );
510+ namespace {
511+
512+ // In a tree starting at "root", insert a sequence node with
513+ // as only child the node identified by "tree"
514+ // within the subtree at "relativeRoot".
515+ ScheduleTree* insertSequenceAbove (
516+ const ScheduleTree* root,
517+ ScheduleTree* relativeRoot,
518+ ScheduleTree* tree) {
519+ auto parent = tree->ancestor (relativeRoot, 1 );
512520 auto childPos = tree->positionInParent (parent);
513521 auto filter = activeDomainPoints (root, tree).universe ();
514522 parent->insertChild (
@@ -518,6 +526,12 @@ ScheduleTree* insertSequenceAbove(ScheduleTree* root, ScheduleTree* tree) {
518526 return parent->child ({childPos});
519527}
520528
529+ } // namespace
530+
531+ ScheduleTree* insertSequenceAbove (ScheduleTree* root, ScheduleTree* tree) {
532+ return insertSequenceAbove (root, root, tree);
533+ }
534+
521535void insertSequenceBelow (
522536 const detail::ScheduleTree* root,
523537 detail::ScheduleTree* tree) {
@@ -544,49 +558,77 @@ namespace {
544558/*
545559 * Insert an empty extension node above "st" in a tree with the given root and
546560 * return a pointer to the inserted extension node.
561+ * The modification is performed within the subtree at "relativeRoot".
547562 */
548563detail::ScheduleTree* insertEmptyExtensionAbove (
549- ScheduleTree* root,
564+ const ScheduleTree* root,
565+ ScheduleTree* relativeRoot,
550566 ScheduleTree* st) {
551567 auto domain = root->elemAs <ScheduleTreeElemDomain>();
552568 CHECK (domain);
553569 auto space = domain->domain_ .get_space ();
554570 auto extension = isl::union_map::empty (space);
555- return insertExtensionAbove (root , st, extension);
571+ return insertExtensionAbove (relativeRoot , st, extension);
556572}
557- } // namespace
558573
559- void insertExtensionLabelAt (
560- ScheduleTree* root,
574+ /*
575+ * Construct an extension map for a zero-dimensional statement
576+ * with the given identifier.
577+ */
578+ isl::map labelExtension (ScheduleTree* root, ScheduleTree* tree, isl::id id) {
579+ auto prefix = prefixScheduleMupa (root, tree);
580+ auto scheduleSpace = prefix.get_space ();
581+ auto space = scheduleSpace.params ().set_from_params ().set_tuple_id (
582+ isl::dim_type::set, id);
583+ auto extensionSpace = scheduleSpace.map_from_domain_and_range (space);
584+ return isl::map::universe (extensionSpace);
585+ }
586+
587+ /*
588+ * Construct a filter node for a zero-dimensional extension statement
589+ * with the given extension map.
590+ */
591+ ScheduleTreeUPtr labelFilterFromExtension (isl::map extension) {
592+ return detail::ScheduleTree::makeFilter (extension.range ());
593+ }
594+
595+ /*
596+ * Given a sequence node in the schedule tree, insert
597+ * an extension with the given extension map and extension filter node
598+ * before the child at position "pos".
599+ * If "pos" is equal to the number of children, then
600+ * the statement is added after the last child.
601+ * The modification is performed within the subtree at "relativeRoot".
602+ */
603+ void insertExtensionAt (
604+ const ScheduleTree* root,
605+ ScheduleTree* relativeRoot,
561606 ScheduleTree* seqNode,
562607 size_t pos,
563- isl::id id) {
564- auto extensionTree = seqNode->ancestor (root, 1 );
608+ isl::union_map extension,
609+ ScheduleTreeUPtr&& filterNode) {
610+ auto extensionTree = seqNode->ancestor (relativeRoot, 1 );
565611 auto extensionNode =
566612 extensionTree->elemAs <detail::ScheduleTreeElemExtension>();
567613 if (!extensionNode) {
568- extensionTree = insertEmptyExtensionAbove (root, seqNode);
614+ extensionTree = insertEmptyExtensionAbove (root, relativeRoot, seqNode);
569615 extensionNode = extensionTree->elemAs <detail::ScheduleTreeElemExtension>();
570616 }
571617 CHECK (extensionNode);
572618 CHECK (seqNode->elemAs <detail::ScheduleTreeElemSequence>());
573- auto prefix = prefixScheduleMupa (root, extensionTree);
574- auto scheduleSpace = prefix.get_space ();
575- auto space = scheduleSpace.params ().set_from_params ().set_tuple_id (
576- isl::dim_type::set, id);
577- auto extensionSpace = scheduleSpace.map_from_domain_and_range (space);
578- auto extension = isl::map::universe (extensionSpace);
579619 extensionNode->extension_ = extensionNode->extension_ .unite (extension);
580- auto filterNode = detail::ScheduleTree::makeFilter (extension.range ());
581620 seqNode->insertChild (pos, std::move (filterNode));
582621}
622+ } // namespace
583623
584- void insertExtensionLabelBefore (
585- ScheduleTree* root,
624+ void insertExtensionBefore (
625+ const ScheduleTree* root,
626+ ScheduleTree* relativeRoot,
586627 ScheduleTree* tree,
587- isl::id id) {
628+ isl::union_map extension,
629+ ScheduleTreeUPtr&& filterNode) {
588630 size_t pos;
589- auto parent = tree->ancestor (root , 1 );
631+ auto parent = tree->ancestor (relativeRoot , 1 );
590632 ScheduleTree* seqTree;
591633 if (tree->elemAs <detail::ScheduleTreeElemExtension>()) {
592634 tree = tree->child ({0 });
@@ -598,21 +640,24 @@ void insertExtensionLabelBefore(
598640 } else if (
599641 parent->elemAs <detail::ScheduleTreeElemFilter>() &&
600642 parent->ancestor (root, 1 )->elemAs <detail::ScheduleTreeElemSequence>()) {
601- seqTree = parent->ancestor (root , 1 );
643+ seqTree = parent->ancestor (relativeRoot , 1 );
602644 pos = parent->positionInParent (seqTree);
603645 } else {
604- seqTree = insertSequenceAbove (root, tree);
646+ seqTree = insertSequenceAbove (root, relativeRoot, tree);
605647 pos = 0 ;
606648 }
607- insertExtensionLabelAt (root, seqTree, pos, id);
649+ insertExtensionAt (
650+ root, relativeRoot, seqTree, pos, extension, std::move (filterNode));
608651}
609652
610- void insertExtensionLabelAfter (
611- ScheduleTree* root,
653+ void insertExtensionAfter (
654+ const ScheduleTree* root,
655+ ScheduleTree* relativeRoot,
612656 ScheduleTree* tree,
613- isl::id id) {
657+ isl::union_map extension,
658+ ScheduleTreeUPtr&& filterNode) {
614659 size_t pos;
615- auto parent = tree->ancestor (root , 1 );
660+ auto parent = tree->ancestor (relativeRoot , 1 );
616661 ScheduleTree* seqTree;
617662 if (tree->elemAs <detail::ScheduleTreeElemExtension>()) {
618663 tree = tree->child ({0 });
@@ -624,13 +669,42 @@ void insertExtensionLabelAfter(
624669 } else if (
625670 parent->elemAs <detail::ScheduleTreeElemFilter>() &&
626671 parent->ancestor (root, 1 )->elemAs <detail::ScheduleTreeElemSequence>()) {
627- seqTree = parent->ancestor (root , 1 );
672+ seqTree = parent->ancestor (relativeRoot , 1 );
628673 pos = parent->positionInParent (seqTree) + 1 ;
629674 } else {
630- seqTree = insertSequenceAbove (root, tree);
675+ seqTree = insertSequenceAbove (root, relativeRoot, tree);
631676 pos = 1 ;
632677 }
633- insertExtensionLabelAt (root, seqTree, pos, id);
678+ insertExtensionAt (
679+ root, relativeRoot, seqTree, pos, extension, std::move (filterNode));
680+ }
681+
682+ void insertExtensionLabelAt (
683+ ScheduleTree* root,
684+ ScheduleTree* seqNode,
685+ size_t pos,
686+ isl::id id) {
687+ auto extension = labelExtension (root, seqNode, id);
688+ auto filterNode = labelFilterFromExtension (extension);
689+ insertExtensionAt (root, root, seqNode, pos, extension, std::move (filterNode));
690+ }
691+
692+ void insertExtensionLabelBefore (
693+ ScheduleTree* root,
694+ ScheduleTree* tree,
695+ isl::id id) {
696+ auto extension = labelExtension (root, tree, id);
697+ auto filterNode = labelFilterFromExtension (extension);
698+ insertExtensionBefore (root, root, tree, extension, std::move (filterNode));
699+ }
700+
701+ void insertExtensionLabelAfter (
702+ ScheduleTree* root,
703+ ScheduleTree* tree,
704+ isl::id id) {
705+ auto extension = labelExtension (root, tree, id);
706+ auto filterNode = labelFilterFromExtension (extension);
707+ insertExtensionAfter (root, root, tree, extension, std::move (filterNode));
634708}
635709
636710namespace {
0 commit comments