@@ -131,6 +131,8 @@ std::vector<T> collectBranchMarkers(T root, T node) {
   return findThreadSpecificMarkers(node);
 }
 
+struct FullSchedule;
+
 /*
  * Transform schedule bands into a union_map.
  * Takes all partial schedules at leaves as MUPAs (without accounting for
@@ -139,7 +141,8 @@ std::vector<T> collectBranchMarkers(T root, T node) {
  * current leaves and transforms them into union maps.
  * Mapping filters are ignored.
  */
-isl::union_map fullSchedule(const detail::ScheduleTree* root) {
+isl::UnionMap<Domain, FullSchedule> fullSchedule(
+    const detail::ScheduleTree* root) {
   using namespace tc::polyhedral::detail;
 
   if (!root->elemAs<ScheduleTreeElemDomain>()) {
@@ -182,7 +185,7 @@ isl::union_map fullSchedule(const detail::ScheduleTree* root) {
       throw promotion::PromotionLogicError(ss.str());
     }
   }
-  return schedule;
+  return isl::UnionMap<Domain, FullSchedule>(schedule);
 }
 
 /*
@@ -263,7 +266,7 @@ bool promotionImprovesCoalescing(
     const detail::ScheduleTree* root,
     const detail::ScheduleTree* node,
     const TensorReferenceGroup& group,
-    isl::union_map schedule) {
+    isl::UnionMap<Domain, FullSchedule> schedule) {
   auto originalAccesses = group.originalAccesses();
 
   auto markers = collectBranchMarkers(root, node);
@@ -313,6 +316,8 @@ isl::union_set collectMappingsTo(const Scop& scop) {
   return mapping;
 }
 
+struct Unrolled;
+
 /*
  * Check that only unrolled loops may appear in access subscripts.
  * Because the scoping point can be above a branching tree, descend into each
@@ -343,11 +348,12 @@ isl::union_set collectMappingsTo(const Scop& scop) {
  * different references may have different values, but all of them remain
  * independent of non-unrolled loop iterators.
  */
+template <typename Outer>
 bool accessSubscriptsAreUnrolledLoops(
     const TensorReferenceGroup& group,
     const detail::ScheduleTree* root,
     const detail::ScheduleTree* scope,
-    isl::multi_union_pw_aff outerSchedule) {
+    isl::MultiUnionPwAff<Domain, Outer> outerSchedule) {
   using namespace detail;
 
   auto nodes = ScheduleTree::collect(scope);
@@ -365,7 +371,7 @@ bool accessSubscriptsAreUnrolledLoops(
     auto subdomain = activeDomainPointsBelow(root, leaf);
 
     auto unrolledDims = isl::union_pw_aff_list(leaf->ctx_, 1);
-    for (auto node : ancestors) {
+    for (const detail::ScheduleTree* node : ancestors) {
       auto band = node->elemAs<detail::ScheduleTreeElemBand>();
       if (!band) {
         continue;
@@ -383,7 +389,8 @@ bool accessSubscriptsAreUnrolledLoops(
 
     auto space = isl::space(leaf->ctx_, 0, unrolledDims.n())
                      .align_params(subdomain.get_space());
-    auto unrolledDimsMupa = isl::multi_union_pw_aff(space, unrolledDims);
+    auto unrolledDimsMupa =
+        isl::MultiUnionPwAff<Domain, Unrolled>(space, unrolledDims);
 
     // It is possible that no loops are unrolled, in which case
     // unrolledDimsMupa is zero-dimensional and needs an explicit domain
@@ -392,10 +399,11 @@ bool accessSubscriptsAreUnrolledLoops(
         unrolledDimsMupa.intersect_domain(group.originalAccesses().domain());
 
     auto accesses = group.originalAccesses();
-    auto schedule = outerSchedule.flat_range_product(unrolledDimsMupa);
-    accesses = accesses.apply_domain(isl::union_map::from(schedule));
+    auto schedule = outerSchedule.range_product(unrolledDimsMupa);
+    auto scheduleMap = schedule.asUnionMap();
+    auto scheduledAccesses = accesses.apply_domain(scheduleMap);
 
-    if (!accesses.is_single_valued()) {
+    if (!scheduledAccesses.is_single_valued()) {
       return false;
     }
   }
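Replacing `flat_range_product` with `range_product` keeps the schedule's range as a nested space whose factors the tags can still name; `asUnionMap()` (a conversion assumed on the tagged wrapper) then drops the tags for the untyped single-valuedness test. That test, reduced to a toy access relation with plain isl types (assuming isl's `isl/cpp.h` bindings; the relations are made up for illustration):

```cpp
#include <isl/cpp.h>
#include <cassert>

int main() {
  isl::ctx ctx(isl_ctx_alloc());
  // Statement S[i, j] reads A[i]; the schedule exposes only i, as if j
  // were an unrolled dimension already accounted for separately.
  auto accesses =
      isl::union_map(ctx, "{ S[i, j] -> A[i] : 0 <= i < 8 and 0 <= j < 8 }");
  auto schedule = isl::union_map(ctx, "{ S[i, j] -> [i] }");
  auto scheduledAccesses = accesses.apply_domain(schedule);
  // Single-valued: each schedule point [i] determines exactly one element,
  // so a register copy indexed by the schedule is well-defined. With
  // accesses A[j] instead, this test would fail.
  assert(scheduledAccesses.is_single_valued());
  isl_ctx_free(ctx.release());
}
```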
@@ -415,23 +423,25 @@ bool accessSubscriptsAreUnrolledLoops(
  * thread associated to a given pair of tensor element and outer schedule
  * iteration.
  */
+template <typename Outer>
 bool isPromotableToRegistersBelow(
     const TensorReferenceGroup& group,
     const detail::ScheduleTree* root,
     const detail::ScheduleTree* scope,
-    isl::multi_union_pw_aff outer,
-    isl::multi_union_pw_aff thread) {
+    isl::MultiUnionPwAff<Domain, Outer> outer,
+    isl::MultiUnionPwAff<Domain, Thread> thread) {
   if (!accessSubscriptsAreUnrolledLoops(
-          group, root, scope, outer.flat_range_product(thread))) {
+          group, root, scope, outer.range_product(thread))) {
     return false;
   }
 
   auto originalAccesses = group.originalAccesses();
-  auto map = isl::union_map::from(outer);
-  map = map.range_product(originalAccesses);
-  map = map.apply_domain(isl::union_map::from(thread));
+  auto outerMap = isl::UnionMap<Domain, Outer>::from(outer);
+  auto pair = outerMap.range_product(originalAccesses);
+  auto threadMap = isl::UnionMap<Domain, Thread>::from(thread);
+  auto threadToPair = pair.apply_domain(threadMap);
 
-  return map.is_injective();
+  return threadToPair.is_injective();
 }
 
 /*
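The final check maps each thread to the (outer iteration, tensor element) pairs it touches and requires injectivity: no pair may belong to two threads, or a per-thread register copy would duplicate state that must stay shared. A self-contained toy version with untagged isl bindings (hypothetical relations, assuming `isl/cpp.h`):

```cpp
#include <isl/cpp.h>
#include <cassert>

int main() {
  isl::ctx ctx(isl_ctx_alloc());
  // Each thread t owns a private element: injective, so promotable.
  auto privateElems = isl::union_map(ctx, "{ T[t] -> A[t] : 0 <= t < 4 }");
  assert(privateElems.is_injective());
  // Pairs of threads share an element: not injective, so promotion to
  // per-thread registers would break the sharing.
  auto sharedElems =
      isl::union_map(ctx, "{ T[t] -> A[floord(t, 2)] : 0 <= t < 4 }");
  assert(!sharedElems.is_injective());
  isl_ctx_free(ctx.release());
}
```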
@@ -654,15 +664,15 @@ void promoteToRegistersBelow(MappedScop& mscop, detail::ScheduleTree* scope) {
   auto blockSchedule = mscop.blockMappingSchedule(mscop.schedule());
 
   // Pure affine schedule without (mapping) filters.
-  auto partialSchedMupa = partialScheduleMupa(root, scope);
+  auto partialSchedMupa = partialScheduleMupa<Scope>(root, scope);
   // Schedule with block mapping filter.
   auto partialSched =
       isl::union_map::from(partialSchedMupa).intersect_domain(blockMapping);
   // The following promotion validity and profitability checks need to be
   // performed with respect to the block mapping, so append the block schedule.
   // If the partial schedule contains it already, it will just end up with
   // identical dimensions without affecting the result of the checks.
-  partialSchedMupa = partialSchedMupa.flat_range_product(blockSchedule);
+  auto partialSchedBlockMupa = partialSchedMupa.range_product(blockSchedule);
 
   for (auto& tensorGroups : groupMap) {
     auto tensorId = tensorGroups.first;
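The comment's claim that re-appending an already-present block schedule is harmless rests on a simple fact: duplicating range dimensions preserves both injectivity and single-valuedness, which is all the subsequent checks observe. A quick demonstration with plain isl types (toy schedule, assuming `isl/cpp.h`):

```cpp
#include <isl/cpp.h>
#include <cassert>

int main() {
  isl::ctx ctx(isl_ctx_alloc());
  auto sched = isl::union_map(ctx, "{ S[i] -> [i] : 0 <= i < 16 }");
  // Appending a copy of dimensions already in the schedule changes
  // nothing that the promotion checks can observe.
  auto withDup = sched.flat_range_product(sched); // { S[i] -> [i, i] }
  assert(sched.is_injective() == withDup.is_injective());
  assert(sched.is_single_valued() == withDup.is_single_valued());
  isl_ctx_free(ctx.release());
}
```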
@@ -676,11 +686,11 @@ void promoteToRegistersBelow(MappedScop& mscop, detail::ScheduleTree* scope) {
       continue;
     }
     if (!isPromotableToRegistersBelow(
-            *group, root, scope, partialSchedMupa, threadSchedule)) {
+            *group, root, scope, partialSchedBlockMupa, threadSchedule)) {
       continue;
     }
     // Check reuse within threads.
-    auto schedule = partialSchedBlockMupa.flat_range_product(threadSchedule);
+    auto schedule = partialSchedBlockMupa.flat_range_product(threadSchedule);
     if (!hasReuseWithin(*group, schedule)) {
       continue;
     }
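Note the deliberate asymmetry this hunk keeps: `range_product` feeds the typed `isPromotableToRegistersBelow`, while plain `flat_range_product` feeds the untyped `hasReuseWithin`. The two products carry the same affine information and differ only in range structure, as this small comparison (made-up relations) shows:

```cpp
#include <isl/cpp.h>
#include <iostream>

int main() {
  isl::ctx ctx(isl_ctx_alloc());
  auto outer = isl::union_map(ctx, "{ S[i] -> B[i] }");
  auto thread = isl::union_map(ctx, "{ S[i] -> T[i mod 32] }");
  // Nested range [[B] -> [T]]: both factors remain nameable, which is
  // what the tagged wrappers rely on.
  std::cout << outer.range_product(thread) << "\n";
  // Flattened anonymous range: fine for untyped consumers, but the
  // factor spaces are no longer recoverable by name.
  std::cout << outer.flat_range_product(thread) << "\n";
  isl_ctx_free(ctx.release());
}
```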