@@ -391,34 +391,6 @@ TensorReferenceGroup::referenceIds() const {
391391}
392392
393393namespace {
394- bool hasCopyExtensionSingleChild (const ScheduleTree* tree) {
395- if (tree->numChildren () != 1 ) {
396- return false ;
397- }
398-
399- auto extensionNode =
400- tree->child ({0 })->elemAs <detail::ScheduleTreeElemExtension>();
401- if (!extensionNode) {
402- return false ;
403- }
404-
405- if ((tree->child ({0 })->numChildren () != 1 ) &&
406- (tree->child ({0 , 0 })->elemAs <detail::ScheduleTreeElemSequence>())) {
407- return false ;
408- }
409-
410- for (auto e : isl::UnionAsVector<isl::union_map>(extensionNode->extension_ )) {
411- if (!e.has_tuple_name (isl::dim_type::out)) {
412- return false ;
413- }
414- if (e.get_tuple_name (isl::dim_type::out) != kReadIdName &&
415- e.get_tuple_name (isl::dim_type::out) != kWriteIdName ) {
416- return false ;
417- }
418- }
419- return true ;
420- }
421-
422394// Construct the set containing all tensor elements.
423395//
424396// Find the Halide image corresponding to the given tensorId. Transform its
@@ -524,48 +496,26 @@ ScheduleTree* insertCopiesUnder(
524496 bool reads = !group.scopedReads ().is_empty ();
525497 bool writes = !group.scopedWrites ().is_empty ();
526498
527- if (hasCopyExtensionSingleChild (tree)) {
528- auto extensionNode = tree->child ({0 });
529- auto sequenceNode = tree->child ({0 , 0 });
530-
531- auto & ext =
532- extensionNode->elemAs <detail::ScheduleTreeElemExtension>()->extension_ ;
533- if (reads) {
534- ext = ext.unite (isl::union_map (readExtension));
535- sequenceNode->insertChild (0 , std::move (readFilterNode));
536- }
537- if (writes) {
538- ext = ext.unite (isl::union_map (writeExtension));
539- sequenceNode->appendChild (std::move (writeFilterNode));
540- }
541- return tree;
499+ if (tree->numChildren () == 0 ) {
500+ // The point underneath a leaf node cannot be referenced,
501+ // so insert a dummy sequence first. It will be extended
502+ // with the reads and/or writes.
503+ insertSequenceBelow (root, tree);
542504 }
543505
544- auto mainCompFilter = activeDomainPoints (root, tree).universe ();
545- auto mainCompFilterNode =
546- ScheduleTree::makeFilter (mainCompFilter, tree->detachChildren ());
547-
548- // XXX: I don't really like the syntax-imposed impossibility to create a
549- // sequence node with no children.
550- auto sequenceNode = ScheduleTree::makeSequence (
551- reads ? std::move (readFilterNode) : std::move (mainCompFilterNode));
552506 if (reads) {
553- sequenceNode->appendChild (std::move (mainCompFilterNode));
507+ insertExtensionBefore (
508+ root, tree, tree->child ({0 }), readExtension, std::move (readFilterNode));
554509 }
555510 if (writes) {
556- sequenceNode->appendChild (std::move (writeFilterNode));
511+ insertExtensionAfter (
512+ root,
513+ tree,
514+ tree->child ({0 }),
515+ writeExtension,
516+ std::move (writeFilterNode));
557517 }
558518
559- auto extensionUmap = isl::union_map::empty (promotionSpace.params ());
560- if (reads) {
561- extensionUmap = extensionUmap.unite (readExtension);
562- }
563- if (writes) {
564- extensionUmap = extensionUmap.unite (writeExtension);
565- }
566- auto extensionNode =
567- ScheduleTree::makeExtension (extensionUmap, std::move (sequenceNode));
568- tree->appendChild (std::move (extensionNode));
569519 return tree;
570520}
571521} // namespace polyhedral
0 commit comments