@@ -157,6 +157,35 @@ std::vector<T> findThreadSpecificMarkers(T root) {
       root, ScheduleTreeType::ThreadSpecificMarker);
 }
 
+/*
+ * Return the thread specific markers in the tree rooted at "root"
+ * that are relevant for "node".
+ *
+ * Every branch in the tree has exactly one thread marker.
+ * If "node" appears underneath a thread marker, then return
+ * that single thread marker.
+ * Otherwise, return the (possibly multiple) thread markers
+ * in the subtree rooted at "node".
+ */
+template <typename T>
+std::vector<T> collectBranchMarkers(T root, T node) {
+  using namespace detail;
+  static_assert(
+      std::is_convertible<T, const ScheduleTree*>::value,
+      "expecting ScheduleTree");
+
+  auto filterMarker = [](T tree) {
+    return tree->type_ == ScheduleTreeType::ThreadSpecificMarker;
+  };
+
+  auto ancestors = node->ancestors(root);
+  ancestors = functional::Filter(filterMarker, ancestors);
+  if (ancestors.size() > 0) {
+    return ancestors;
+  }
+  return findThreadSpecificMarkers(node);
+}
+
 /*
  * Transform schedule bands into a union_map.
  * Takes all partial schedules at leaves as MUPAs (without accounting for
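
The helper added above relies on the invariant that every branch of the schedule tree contains exactly one thread-specific marker. The following standalone sketch (a toy Node type with a marker flag, not the ScheduleTree API and not part of the patch) illustrates the two cases collectBranchMarkers distinguishes: a node that sits below a marker yields only that ancestor marker, while any other node yields all markers in its subtree.

#include <iostream>
#include <memory>
#include <vector>

struct Node {
  bool isThreadMarker = false;
  Node* parent = nullptr;
  std::vector<std::unique_ptr<Node>> children;

  Node* addChild(bool marker) {
    children.push_back(std::make_unique<Node>());
    children.back()->isThreadMarker = marker;
    children.back()->parent = this;
    return children.back().get();
  }
};

// Analogue of findThreadSpecificMarkers: collect markers in the subtree of "n".
void collectBelow(Node* n, std::vector<Node*>& out) {
  if (n->isThreadMarker) {
    out.push_back(n);
  }
  for (auto& c : n->children) {
    collectBelow(c.get(), out);
  }
}

// Analogue of collectBranchMarkers: prefer a marker among the ancestors of
// "node"; if there is none, fall back to the markers below "node".
std::vector<Node*> branchMarkers(Node* root, Node* node) {
  for (Node* a = node->parent; a != nullptr; a = a->parent) {
    if (a->isThreadMarker) {
      return {a};
    }
  }
  std::vector<Node*> below;
  collectBelow(node, below);
  return below;
}

int main() {
  Node root;
  Node* band = root.addChild(false);
  Node* marker = band->addChild(true); // branch 1: marker above the leaf
  Node* leaf = marker->addChild(false);
  band->addChild(true); // branch 2: marker in a sibling subtree
  std::cout << branchMarkers(&root, leaf).size() << "\n"; // 1: the ancestor marker
  std::cout << branchMarkers(&root, band).size() << "\n"; // 2: both markers below "band"
}
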
@@ -277,27 +306,6 @@ isl::map makeNextElementMap(isl::space setSpace, unsigned dim) {
   return isl::map(identityMA);
 }
 
-// Obtain the depth of the schedule dimension that was mapped to threadIdx.x
-// for the domain elements identified by "s". Assumes the depth is the same
-// for all these elements.
-size_t computeThreadIdxXScheduleDepth(
-    const ThreadIdxXScheduleDepthState& threadIdxXScheduleDepthState,
-    isl::union_set s) {
-  std::unordered_set<size_t> depths;
-  for (auto p : threadIdxXScheduleDepthState) {
-    if (!p.first.intersect(s).is_empty()) {
-      depths.insert(p.second);
-    }
-  }
-  if (depths.size() != 1) {
-    std::stringstream ss;
-    ss << "threadIdx.x depth " << (depths.size() == 0 ? "unknown" : "diverged")
-       << " for " << s;
-    throw promotion::PromotionLogicError(ss.str());
-  }
-  return *depths.begin();
-}
-
 /*
  * Return the outermost thread mapping filter among the ancestors of "node",
  * assuming that there is at least one.
@@ -318,42 +326,49 @@ const detail::ScheduleTree* findThreadMappingAncestor(
  *
  * If the reference group is not already accessed in a coalesced way,
  * then the group should be promoted.
+ * If a branch is mapped to a single thread, then the accesses
+ * in that branch are not considered to contribute to the usefulness
+ * of promoting.
+ *
  * The check for coalesced accesses is performed as follows.
  * Check if incrementing the schedule dimension mapped to
  * Thread::x results in the last tensor index being incremented as well.
  * Since accesses in the group may belong to different statements, which may
- * have different loops mapped to Thread::x, perform the check for each basic
- * map in the union of access maps taking into account which dimension is
- * mapped for a particular statement (domain of the basic map). The group is
+ * have different loops mapped to Thread::x, perform the check for each thread
+ * mapping on the statements active at "node" (either a single ancestor,
+ * or one or more descendants).
+ * The iteration over the spaces is used to handle the case where
+ * one of the subbranches does not access the tensor and
+ * the scheduled accesses are empty. The group is
  * accessed in a coalesced way if all references in this group are accessed in
  * a coalesced way.
  */
 bool promotionImprovesCoalescing(
-    const ThreadIdxXScheduleDepthState& threadIdxXScheduleDepthState,
+    const detail::ScheduleTree* root,
+    const detail::ScheduleTree* node,
     const TensorReferenceGroup& group,
-    isl::union_map schedule,
-    isl::union_set activePoints) {
+    isl::union_map schedule) {
   auto originalAccesses = group.originalAccesses();
 
-  for (auto accessMap : isl::UnionAsVector<isl::union_map>(originalAccesses)) {
-    for (auto access : accessMap.get_basic_map_list()) {
+  auto markers = collectBranchMarkers(root, node);
+  for (auto marker : markers) {
+    auto mapping = findThreadMappingAncestor(root, marker);
+    size_t nMappedThreads = marker->scheduleDepth(mapping);
+    if (nMappedThreads == 0) {
+      continue;
+    }
+    auto depth = marker->scheduleDepth(root);
+    auto activePoints = activeDomainPoints(root, mapping);
+    auto localAccesses = originalAccesses.intersect_domain(activePoints);
+    auto scheduledAccesses = localAccesses.apply_domain(schedule);
+    for (auto access : isl::UnionAsVector<isl::union_map>(scheduledAccesses)) {
+      auto scheduleSpace = access.get_space().domain();
       auto tensorSpace = access.get_space().range();
       auto elementToNext = makeNextElementMap(
           tensorSpace, tensorSpace.dim(isl::dim_type::set) - 1);
-      auto domainUMap = isl::union_set(isl::set(access.domain()));
-      int threadIdxXDepth = computeThreadIdxXScheduleDepth(
-          threadIdxXScheduleDepthState, domainUMap.intersect(activePoints));
-      auto partialScheduleUMap =
-          schedule.intersect_domain(domainUMap.universe());
-      if (partialScheduleUMap.n_map() != 1) {
-        throw promotion::PromotionLogicError("expected single schedule space");
-      }
-      auto partialSchedule = isl::map::from_union_map(partialScheduleUMap);
-      auto scheduleToNextX = makeNextElementMap(
-          partialSchedule.get_space().range(), threadIdxXDepth);
-      auto scheduledAccess = isl::map(access).apply_domain(partialSchedule);
-      auto accessedByAdjacentX = scheduleToNextX.apply_domain(scheduledAccess)
-                                     .apply_range(scheduledAccess);
+      auto scheduleToNextX = makeNextElementMap(scheduleSpace, depth - 1);
+      auto accessedByAdjacentX =
+          scheduleToNextX.apply_domain(access).apply_range(access);
 
       if (not accessedByAdjacentX.is_subset(elementToNext)) {
         return true;
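
The criterion behind accessedByAdjacentX.is_subset(elementToNext) can be restated without isl: a reference is accessed in a coalesced way when stepping the schedule dimension mapped to threadIdx.x by one also steps the innermost (last) tensor index by one, and promotion is considered useful when this fails for some reference. Below is a minimal sketch of that check using plain integers and two hypothetical accesses into a row-major 2D tensor; it is an illustration only, not part of the patch.

#include <array>
#include <iostream>

// Scheduled accesses into a 2D tensor A[N][M]: both take a schedule point
// (i, x), where "x" is the dimension mapped to threadIdx.x, and return the
// (row, column) element that is touched.
std::array<long, 2> coalescedAccess(long i, long x) {
  return {i, x}; // A[i][x]: the column (last index) follows threadIdx.x
}
std::array<long, 2> stridedAccess(long i, long x) {
  return {x, i}; // A[x][i]: threadIdx.x walks over rows instead
}

// Analogue of the subset check: does bumping the thread-x schedule dimension
// by one bump exactly the last tensor index by one?
template <typename Access>
bool accessIsCoalesced(Access access, long i, long x) {
  auto here = access(i, x);
  auto next = access(i, x + 1);
  return next[0] == here[0] && next[1] == here[1] + 1;
}

int main() {
  std::cout << accessIsCoalesced(coalescedAccess, 3, 7) << "\n"; // 1: no promotion needed
  std::cout << accessIsCoalesced(stridedAccess, 3, 7) << "\n";   // 0: promotion would help
}

In the patch itself the same comparison is built with makeNextElementMap on both the schedule space (at depth - 1, the dimension mapped to threadIdx.x) and the tensor space (at its last dimension).
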
@@ -467,7 +482,6 @@ std::vector<detail::ScheduleTree*> bandsSplitAfterDepth(
  */
 void promoteToSharedGreedy(
     Scop& scop,
-    const ThreadIdxXScheduleDepthState& threadIdxXScheduleDepthState,
     const Block& block,
     size_t depth,
     size_t maxMemory) {
@@ -561,11 +575,7 @@ void promoteToSharedGreedy(
     // Do not promote if the group features no reuse and is accessed in a
     // coalesced way.
     if (!hasReuseWithin(*group, partialSchedMupa) &&
-        !promotionImprovesCoalescing(
-            threadIdxXScheduleDepthState,
-            *group,
-            fullSched,
-            activePoints)) {
+        !promotionImprovesCoalescing(root, bandNode, *group, fullSched)) {
       continue;
     }
 
@@ -586,17 +596,12 @@ void promoteToSharedGreedy(
 
 void promoteGreedilyAtDepth(
     MappedScop& mscop,
-    const ThreadIdxXScheduleDepthState& threadIdxXScheduleDepthState,
     size_t depth,
     size_t sharedMemorySize,
     bool unrollCopies) {
   // 1. Promote using heuristic.
   promoteToSharedGreedy(
-      mscop.scop(),
-      threadIdxXScheduleDepthState,
-      mscop.numThreads,
-      depth,
-      sharedMemorySize);
+      mscop.scop(), mscop.numThreads, depth, sharedMemorySize);
 
   // 2. Map copies to shared, state by copy
   mapCopiesToThreads(mscop, unrollCopies);