@@ -200,37 +200,6 @@ isl::union_map fullSchedule(const detail::ScheduleTree* root) {
   return schedule;
 }
 
-/*
- * Insert map constraints that equate first "nDims" input dimensions to newly
- * introduced parameters.
- */
-isl::map fixOuterInputDimsAsParameters(isl::map map, unsigned nDims) {
-  if (nDims < 0 || nDims > map.dim(isl::dim_type::in)) {
-    std::stringstream ss;
-    ss << nDims << " is out of [0, " << map.dim(isl::dim_type::in)
-     << ") range";
-    throw promotion::OutOfRangeException(ss.str());
-  }
-
-  auto fixedMap = map;
-  auto localSpace = isl::local_space(map.get_space().domain());
-  auto nParams = map.dim(isl::dim_type::param);
-  localSpace = localSpace.add_dims(isl::dim_type::param, nDims);
-  for (unsigned i = 0; i < nDims; ++i) {
-    localSpace = localSpace.set_dim_name(
-        isl::dim_type::param,
-        nParams + i,
-        "__tcFixerParam" + std::to_string(i));
-  }
-  for (unsigned i = 0; i < nDims; ++i) {
-    auto left = isl::aff(localSpace, isl::dim_type::param, nParams + i);
-    auto right = isl::aff(localSpace, isl::dim_type::set, i);
-    auto dom = isl::aff_set(left) == right;
-    fixedMap = fixedMap.intersect_domain(dom);
-  }
-  return fixedMap;
-}
-
 /*
  * Check if a reference group features reuse within the "outer" schedule.
  * In particular, check that for some given point in the outer schedule and
@@ -339,19 +308,25 @@ bool promotionImprovesCoalescing(
 }
 
 /*
- * Check if the given "group" can be promoted to registers for the given active
- * domain points under full "schedule" where "nThreads" consecutive dimensions
- * at "depth"
- * are mapped to threads (the innermost of them being mapped to thread x).
+ * Check if the given "group" can be promoted to registers for the given
+ * mapping to thread identifiers and within the given outer schedule.
  *
  * In particular, the group's footprint must contain only one element and the
- * same tensor element should never be accessed by two different threads.
+ * same tensor element should never be accessed by two different threads
+ * within the same iteration of the outer schedule.
+ * The second test is performed by checking that there is only a single
+ * thread associated to a given pair of tensor element and outer schedule
+ * iteration.
+ * Note that the test for a single thread is performed by looking
+ * at the range of "thread". This range may be larger than the number
+ * of threads, such that multiple instances may get mapped to the same thread.
+ * Requiring different such instances is therefore slightly more conservative
+ * than strictly needed.
  */
 bool isPromotableToRegisterBelowThreads(
     const TensorReferenceGroup& group,
-    isl::union_map schedule,
-    size_t depth,
-    size_t nThreads) {
+    isl::multi_union_pw_aff outer,
+    isl::multi_union_pw_aff thread) {
   auto originalAccesses = group.originalAccesses();
 
   // Return early if more than one element needs to be stored in registers.
@@ -364,28 +339,11 @@ bool isPromotableToRegisterBelowThreads(
     return false;
   }
 
-  auto scheduledAccesses = originalAccesses.apply_domain(schedule);
-
-  // Scheduled accesses contain maps from schedule dimensions to tensor
-  // subscripts. Compute the relation between the schedule dimensions
-  // mapped to threads and tensor subscripts by first removing dimensions
-  // following the one mapped to thread x (last one assuming inverse mapping
-  // order), then by equating all dimensions not mapped to threads to
-  // parameters. Promotion to registers is only allowed if the resulting
-  // relation is injective, i.e. the same tensor element is never accessed by
-  // more than one thread. Note that our current check is overly conservative
-  // because different values of schedule dimension may get mapped to the same
-  // thread, in which case they could access the same tensor element.
-  for (auto sa : isl::UnionAsVector<isl::union_map>(scheduledAccesses)) {
-    sa = sa.project_out(
-        isl::dim_type::in, depth, sa.dim(isl::dim_type::in) - depth);
-    sa = fixOuterInputDimsAsParameters(sa, depth - nThreads);
-    if (!sa.is_injective()) {
-      return false;
-    }
-  }
+  auto map = isl::union_map::from(outer);
+  map = map.range_product(group.originalAccesses());
+  map = map.apply_domain(isl::union_map::from(thread));
 
-  return true;
+  return map.is_injective();
 }
 
 /*
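For context on the hunk above: the rewritten check reduces to a single isl injectivity query on the composed relation from thread identifiers to (outer iteration, tensor element) pairs. The snippet below is a minimal standalone sketch of that property, assuming the stock isl C++ bindings rather than TC's isl wrappers; the tuple name T and the toy access patterns are invented for illustration, not taken from the commit.

#include <cassert>

#include <isl/ctx.h>
#include <isl/cpp.h>

int main() {
  isl_ctx* raw = isl_ctx_alloc();
  {
    isl::ctx ctx(raw);
    // Each thread t touches its own element A[t] at outer iteration 0:
    // every (iteration, element) pair has a single preimage thread, the
    // relation is injective, and such a group could live in registers.
    isl::union_map perThread(ctx, "{ T[t] -> [0, t] : 0 <= t < 32 }");
    assert(perThread.is_injective());
    // All 32 threads touch element 0 at iteration 0: the pair (0, 0) has
    // 32 preimage threads, the relation is not injective, and promotion
    // would be rejected.
    isl::union_map shared(ctx, "{ T[t] -> [0, 0] : 0 <= t < 32 }");
    assert(!shared.is_injective());
  }
  isl_ctx_free(raw);
  return 0;
}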
@@ -573,22 +531,16 @@ void promoteToRegistersBelowThreads(Scop& scop, size_t nRegisters) {
 
   auto root = scop.scheduleRoot();
 
-  auto fullSched = fullSchedule(root);
   {
     auto markers = findThreadSpecificMarkers(root);
 
     for (auto marker : markers) {
       auto partialSched = prefixSchedule(root, marker);
       // Pure affine schedule without (mapping) filters.
-      auto partialSchedMupa = prefixScheduleMupa(root, marker);
-
-      auto depth = marker->scheduleDepth(root);
-
-      // Thread mapping filters are inserted immediately above the members
-      // mapped to threads. The number of intermediate band members
-      // is therefore equal to the number of mapped thread identifiers.
       auto mapping = findThreadMappingAncestor(root, marker);
-      size_t nMappedThreads = marker->scheduleDepth(mapping);
+      auto prefixSchedMupa = prefixScheduleMupa(root, mapping);
+      auto mapSchedMupa = infixScheduleMupa(root, mapping, marker);
+      auto partialSchedMupa = prefixSchedMupa.flat_range_product(mapSchedMupa);
 
       auto groupMap = TensorReferenceGroup::accessedBySubtree(marker, scop);
       for (auto& tensorGroups : groupMap) {
@@ -603,7 +555,7 @@ void promoteToRegistersBelowThreads(Scop& scop, size_t nRegisters) {
             continue;
           }
           if (!isPromotableToRegisterBelowThreads(
-                  *group, fullSched, depth, nMappedThreads)) {
+                  *group, prefixSchedMupa, mapSchedMupa)) {
             continue;
           }
           if (!hasReuseWithin(*group, partialSchedMupa)) {
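As a side note on the schedule composition in the hunks above: flat_range_product concatenates the output dimensions of the prefix schedule (computed above the thread mapping filter) and the infix schedule (the band members mapped to threads), producing the partial schedule passed to hasReuseWithin, while the two pieces are handed separately to isPromotableToRegisterBelowThreads as the outer schedule and the thread mapping. The following toy sketch again assumes the stock isl C++ bindings and an invented statement S; it is only meant to show what flat_range_product does to two one-dimensional pieces.

#include <iostream>

#include <isl/ctx.h>
#include <isl/cpp.h>

int main() {
  isl_ctx* raw = isl_ctx_alloc();
  {
    isl::ctx ctx(raw);
    // Hypothetical outer (prefix) schedule and thread-mapped (infix) piece.
    isl::multi_union_pw_aff outer(ctx, "[{ S[i, j] -> [(i)] }]");
    isl::multi_union_pw_aff threadMapped(ctx, "[{ S[i, j] -> [(j)] }]");
    // Concatenating the two 1-D schedules yields the 2-D partial schedule
    // [{ S[i, j] -> [(i)] }, { S[i, j] -> [(j)] }].
    auto partial = outer.flat_range_product(threadMapped);
    std::cout << partial.to_str() << "\n";
  }
  isl_ctx_free(raw);
  return 0;
}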