3434#include " tc/core/polyhedral/mapping_types.h"
3535#include " tc/core/polyhedral/schedule_tree_elem.h"
3636#include " tc/core/polyhedral/schedule_tree_matcher.h"
37+ #include " tc/core/polyhedral/schedule_utils.h"
3738#include " tc/core/scope_guard.h"
3839#include " tc/external/isl.h"
3940
@@ -47,164 +48,6 @@ namespace tc {
4748namespace polyhedral {
4849using namespace detail ;
4950
50- isl::union_map extendSchedule (
51- const ScheduleTree* node,
52- isl::union_map schedule) {
53- if (auto bandElem = node->elemAs <ScheduleTreeElemBand>()) {
54- if (bandElem->nMember () > 0 ) {
55- schedule =
56- schedule.flat_range_product (isl::union_map::from (bandElem->mupa_ ));
57- }
58- } else if (auto filterElem = node->elemAs <ScheduleTreeElemFilter>()) {
59- schedule = schedule.intersect_domain (filterElem->filter_ );
60- } else if (auto extensionElem = node->elemAs <ScheduleTreeElemExtension>()) {
61- // FIXME: we may need to restrict the range of reversed extension map to
62- // schedule values that correspond to active domain elements at this
63- // point.
64- schedule = schedule.unite (
65- extensionElem->extension_ .reverse ().intersect_range (schedule.range ()));
66- }
67-
68- return schedule;
69- }
70-
71- namespace {
72- isl::union_map partialScheduleImpl (
73- const ScheduleTree* root,
74- const ScheduleTree* node,
75- bool useNode) {
76- auto nodes = node->ancestors (root);
77- if (useNode) {
78- nodes.push_back (node);
79- }
80- TC_CHECK_GT (nodes.size (), 0u ) << " root node does not have a prefix schedule" ;
81- auto domain = root->elemAs <ScheduleTreeElemDomain>();
82- TC_CHECK (domain);
83- auto schedule = isl::union_map::from_domain (domain->domain_ );
84- for (auto anc : nodes) {
85- if (anc->elemAs <ScheduleTreeElemDomain>()) {
86- TC_CHECK (anc == root);
87- } else {
88- schedule = extendSchedule (anc, schedule);
89- }
90- }
91- return schedule;
92- }
93- } // namespace
94-
95- isl::union_map prefixSchedule (
96- const ScheduleTree* root,
97- const ScheduleTree* node) {
98- return partialScheduleImpl (root, node, false );
99- }
100-
101- isl::union_map partialSchedule (
102- const ScheduleTree* root,
103- const ScheduleTree* node) {
104- return partialScheduleImpl (root, node, true );
105- }
106-
107- namespace {
108- /*
109- * If "node" is any filter, then intersect "domain" with that filter.
110- */
111- isl::union_set applyFilter (isl::union_set domain, const ScheduleTree* node) {
112- if (auto filterElem = node->elemAs <ScheduleTreeElemFilter>()) {
113- return domain.intersect (filterElem->filter_ );
114- }
115- return domain;
116- }
117-
118- /*
119- * If "node" is a mapping, then intersect "domain" with its filter.
120- */
121- isl::union_set applyMapping (isl::union_set domain, const ScheduleTree* node) {
122- if (auto filterElem = node->elemAs <ScheduleTreeElemMapping>()) {
123- return domain.intersect (filterElem->filter_ );
124- }
125- return domain;
126- }
127-
128- // Get the set of domain elements that are active below
129- // the given branch of nodes, filtered using "filter".
130- //
131- // Domain elements are introduced by the root domain node. Some nodes
132- // refine this set of elements based on "filter". Extension nodes
133- // are considered to introduce additional domain points.
134- isl::union_set collectDomain (
135- const ScheduleTree* root,
136- const vector<const ScheduleTree*>& nodes,
137- isl::union_set (*filter)(isl::union_set domain, const ScheduleTree* node)) {
138- auto domainElem = root->elemAs <ScheduleTreeElemDomain>();
139- TC_CHECK (domainElem) << " root must be a Domain node" << *root;
140-
141- auto domain = domainElem->domain_ ;
142-
143- for (auto anc : nodes) {
144- domain = filter (domain, anc);
145- if (auto extensionElem = anc->elemAs <ScheduleTreeElemExtension>()) {
146- auto parentSchedule = prefixSchedule (root, anc);
147- auto extension = extensionElem->extension_ ;
148- TC_CHECK (parentSchedule) << " missing root domain node" ;
149- parentSchedule = parentSchedule.intersect_domain (domain);
150- domain = domain.unite (parentSchedule.range ().apply (extension));
151- }
152- }
153- return domain;
154- }
155-
156- // Get the set of domain elements that are active below
157- // the given branch of nodes.
158- isl::union_set activeDomainPointsHelper (
159- const ScheduleTree* root,
160- const vector<const ScheduleTree*>& nodes) {
161- return collectDomain (root, nodes, &applyFilter);
162- }
163-
164- } // namespace
165-
166- isl::union_set prefixMappingFilter (
167- const ScheduleTree* root,
168- const ScheduleTree* node) {
169- return collectDomain (root, node->ancestors (root), &applyMapping);
170- }
171-
172- isl::union_set activeDomainPoints (
173- const ScheduleTree* root,
174- const ScheduleTree* node) {
175- return activeDomainPointsHelper (root, node->ancestors (root));
176- }
177-
178- isl::union_set activeDomainPointsBelow (
179- const ScheduleTree* root,
180- const ScheduleTree* node) {
181- auto ancestors = node->ancestors (root);
182- ancestors.emplace_back (node);
183- return activeDomainPointsHelper (root, ancestors);
184- }
185-
186- vector<ScheduleTree*> collectScheduleTreesPath (
187- std::function<ScheduleTree*(ScheduleTree*)> next,
188- ScheduleTree* start) {
189- vector<ScheduleTree*> res{start};
190- auto n = start;
191- while ((n = next (n)) != nullptr ) {
192- res.push_back (n);
193- }
194- return res;
195- }
196-
197- vector<const ScheduleTree*> collectScheduleTreesPath (
198- std::function<const ScheduleTree*(const ScheduleTree*)> next,
199- const ScheduleTree* start) {
200- vector<const ScheduleTree*> res{start};
201- auto n = start;
202- while ((n = next (n)) != nullptr ) {
203- res.push_back (n);
204- }
205- return res;
206- }
207-
20851// Replace "tree" in the list of its parent's children with newTree.
20952// Returns the pointer to newTree for call chaining purposes.
21053ScheduleTree* swapSubtree (
@@ -432,85 +275,6 @@ ScheduleTree* bandScale(ScheduleTree* tree, const vector<size_t>& scales) {
432275 return tree;
433276}
434277
435- namespace {
436-
437- template <typename T>
438- vector<T> reversed (const vector<T>& vec) {
439- vector<T> result;
440- result.reserve (vec.size ());
441- result.insert (result.begin (), vec.rbegin (), vec.rend ());
442- return result;
443- }
444-
445- template <typename T>
446- vector<const ScheduleTree*> filterType (const vector<const ScheduleTree*>& vec) {
447- vector<const ScheduleTree*> result;
448- for (auto e : vec) {
449- if (e->elemAs <T>()) {
450- result.push_back (e);
451- }
452- }
453- return result;
454- }
455-
456- template <typename T, typename Func>
457- T foldl (const vector<const ScheduleTree*> vec, Func op, T init = T()) {
458- T value = init;
459- for (auto st : vec) {
460- value = op (st, value);
461- }
462- return value;
463- }
464-
465- template <typename ... Args>
466- ostream& operator <<(ostream& os, const vector<Args...>& v) {
467- os << " [" ;
468- bool first = true ;
469- for (auto const & ve : v) {
470- if (!first) {
471- os << " , " ;
472- }
473- os << ve;
474- first = true ;
475- }
476- os << " ]" ;
477- return os;
478- }
479- } // namespace
480-
481- isl::multi_union_pw_aff infixScheduleMupa (
482- const ScheduleTree* root,
483- const ScheduleTree* relativeRoot,
484- const ScheduleTree* tree) {
485- auto domainElem = root->elemAs <ScheduleTreeElemDomain>();
486- TC_CHECK (domainElem);
487- auto domain = domainElem->domain_ .universe ();
488- auto zero = isl::multi_val::zero (domain.get_space ().set_from_params ());
489- auto prefix = isl::multi_union_pw_aff (domain, zero);
490- prefix = foldl (
491- filterType<ScheduleTreeElemBand>(tree->ancestors (relativeRoot)),
492- [](const ScheduleTree* st, isl::multi_union_pw_aff pref) {
493- auto mupa = st->elemAs <ScheduleTreeElemBand>()->mupa_ ;
494- return pref.flat_range_product (mupa);
495- },
496- prefix);
497- return prefix;
498- }
499-
500- isl::multi_union_pw_aff prefixScheduleMupa (
501- const ScheduleTree* root,
502- const ScheduleTree* tree) {
503- return infixScheduleMupa (root, root, tree);
504- }
505-
506- isl::multi_union_pw_aff partialScheduleMupa (
507- const detail::ScheduleTree* root,
508- const detail::ScheduleTree* tree) {
509- auto prefix = prefixScheduleMupa (root, tree);
510- auto band = tree->elemAs <ScheduleTreeElemBand>();
511- return band ? prefix.flat_range_product (band->mupa_ ) : prefix;
512- }
513-
514278void updateTopLevelContext (detail::ScheduleTree* root, isl::set context) {
515279 if (!matchOne (tc::polyhedral::domain (tc::polyhedral::context (any ())), root)) {
516280 root->appendChild (ScheduleTree::makeContext (
@@ -832,55 +596,5 @@ void orderAfter(ScheduleTree* root, ScheduleTree* tree, isl::union_set filter) {
832596 seq->insertChild (0 , gistedFilter (other, parent->detachChild (childPos)));
833597 parent->insertChild (childPos, std::move (seq));
834598}
835-
836- /*
837- * Extract a mapping from the domain elements active at "tree"
838- * to identifiers "ids", where all branches in "tree"
839- * are assumed to have been mapped to these identifiers.
840- * The result lives in a space of the form "tupleId"["ids"...].
841- */
842- isl::multi_union_pw_aff extractDomainToIds (
843- const detail::ScheduleTree* root,
844- const detail::ScheduleTree* tree,
845- const std::vector<mapping::MappingId>& ids,
846- isl::id tupleId) {
847- using namespace polyhedral ::detail;
848-
849- auto space = isl::space (tree->ctx_ , 0 );
850- auto empty = isl::union_set::empty (space);
851- space = space.named_set_from_params_id (tupleId, ids.size ());
852- auto zero = isl::multi_val::zero (space);
853- auto domainToIds = isl::multi_union_pw_aff (empty, zero);
854-
855- for (auto mapping : tree->collect (tree, ScheduleTreeType::Mapping)) {
856- auto mappingNode = mapping->elemAs <ScheduleTreeElemMapping>();
857- auto list = isl::union_pw_aff_list (tree->ctx_ , ids.size ());
858- for (auto id : ids) {
859- if (mappingNode->mapping .count (id) == 0 ) {
860- break ;
861- }
862- auto idMap = mappingNode->mapping .at (id);
863- list = list.add (idMap);
864- }
865- // Ignore this node if it does not map to all required ids.
866- if (static_cast <size_t >(list.size ()) != ids.size ()) {
867- continue ;
868- }
869- auto nodeToIds = isl::multi_union_pw_aff (space, list);
870- auto active = activeDomainPoints (root, mapping);
871- TC_CHECK (active.intersect (domainToIds.domain ()).is_empty ())
872- << " conflicting mappings; are the filters in the tree disjoint?" ;
873- nodeToIds = nodeToIds.intersect_domain (active);
874- domainToIds = domainToIds.union_add (nodeToIds);
875- }
876-
877- auto active = activeDomainPoints (root, tree);
878- TC_CHECK (active.is_subset (domainToIds.domain ()))
879- << " not all domain points of\n "
880- << active << " \n were mapped to the required ids" ;
881-
882- return domainToIds;
883- }
884-
885599} // namespace polyhedral
886600} // namespace tc
0 commit comments