@@ -251,31 +251,25 @@ isl::set makeParamContext(isl::ctx ctx, const ParameterVector& params) {
251251namespace {
252252
253253isl::map extractAccess (
254- isl::set domain,
254+ const IterationDomain& domain,
255255 const IRNode* op,
256256 const std::string& tensor,
257257 const std::vector<Expr>& args,
258258 AccessMap* accesses) {
259259 // Make an isl::map representing this access. It maps from the iteration space
260260 // to the tensor's storage space, using the coordinates accessed.
261+ // First construct a set describing the accessed element
262+ // in terms of the parameters (including those corresponding
263+ // to the outer loop iterators) and then convert this set
264+ // into a map in terms of the iteration domain.
261265
262- isl::space domainSpace = domain.get_space ();
263- isl::space paramSpace = domainSpace.params ();
266+ isl::space paramSpace = domain.paramSpace ;
264267 isl::id tensorID (paramSpace.get_ctx (), tensor);
265- auto rangeSpace = paramSpace.named_set_from_params_id (tensorID, args.size ());
268+ auto tensorSpace = paramSpace.named_set_from_params_id (tensorID, args.size ());
266269
267- // Add a tag to the domain space so that we can maintain a mapping
268- // between each access in the IR and the reads/writes maps.
269- std::string tag = " __tc_ref_" + std::to_string (accesses->size ());
270- isl::id tagID (domain.get_ctx (), tag);
271- accesses->emplace (op, tagID);
272- isl::space tagSpace = paramSpace.named_set_from_params_id (tagID, 0 );
273- domainSpace = domainSpace.product (tagSpace);
274-
275- // Start with a totally unconstrained relation - every point in
276- // the iteration domain could write to every point in the allocation.
277- isl::map map =
278- isl::map::universe (domainSpace.map_from_domain_and_range (rangeSpace));
270+ // Start with a totally unconstrained set - every point in
271+ // the allocation could be accessed.
272+ isl::set access = isl::set::universe (tensorSpace);
279273
280274 for (size_t i = 0 ; i < args.size (); i++) {
281275 // Then add one equality constraint per dimension to encode the
@@ -285,19 +279,34 @@ isl::map extractAccess(
285279
286280 // The coordinate written to in the range ...
287281 auto rangePoint =
288- isl::pw_aff (isl::local_space (rangeSpace ), isl::dim_type::set, i);
289- // ... equals the coordinate accessed as a function of the domain .
290- auto domainPoint = halide2isl::makeIslAffFromExpr (domainSpace , args[i]);
282+ isl::pw_aff (isl::local_space (tensorSpace ), isl::dim_type::set, i);
283+ // ... equals the coordinate accessed as a function of the parameters .
284+ auto domainPoint = halide2isl::makeIslAffFromExpr (tensorSpace , args[i]);
291285 if (!domainPoint.is_null ()) {
292- map = map .intersect (isl::pw_aff (domainPoint).eq_map (rangePoint));
286+ access = access .intersect (isl::pw_aff (domainPoint).eq_set (rangePoint));
293287 }
294288 }
295289
290+ // Now convert the set into a relation with respect to the iteration domain.
291+ auto map = access.unbind_params_insert_domain (domain.tuple );
292+
293+ // Add a tag to the domain space so that we can maintain a mapping
294+ // between each access in the IR and the reads/writes maps.
295+ std::string tag = " __tc_ref_" + std::to_string (accesses->size ());
296+ isl::id tagID (domain.paramSpace .get_ctx (), tag);
297+ accesses->emplace (op, tagID);
298+ isl::space domainSpace = map.get_space ().domain ();
299+ isl::space tagSpace = domainSpace.params ().named_set_from_params_id (tagID, 0 );
300+ domainSpace = domainSpace.product (tagSpace).unwrap ();
301+ map = map.preimage_domain (isl::multi_aff::domain_map (domainSpace));
302+
296303 return map;
297304}
298305
299- std::pair<isl::union_map, isl::union_map>
300- extractAccesses (isl::set domain, const Stmt& s, AccessMap* accesses) {
306+ std::pair<isl::union_map, isl::union_map> extractAccesses (
307+ const IterationDomain& domain,
308+ const Stmt& s,
309+ AccessMap* accesses) {
301310 class FindAccesses : public IRGraphVisitor {
302311 using IRGraphVisitor::visit;
303312
@@ -315,17 +324,17 @@ extractAccesses(isl::set domain, const Stmt& s, AccessMap* accesses) {
315324 writes.unite (extractAccess (domain, op, op->name , op->args , accesses));
316325 }
317326
318- const isl::set & domain;
327+ const IterationDomain & domain;
319328 AccessMap* accesses;
320329
321330 public:
322331 isl::union_map reads, writes;
323332
324- FindAccesses (const isl::set & domain, AccessMap* accesses)
333+ FindAccesses (const IterationDomain & domain, AccessMap* accesses)
325334 : domain(domain),
326335 accesses (accesses),
327- reads(isl::union_map::empty(domain.get_space())),
328- writes(isl::union_map::empty(domain.get_space())) {}
336+ reads(isl::union_map::empty(domain.tuple. get_space())),
337+ writes(isl::union_map::empty(domain.tuple. get_space())) {}
329338 } finder(domain, accesses);
330339 s.accept(&finder);
331340 return {finder.reads , finder.writes };
@@ -440,7 +449,8 @@ isl::schedule makeScheduleTreeHelper(
440449 schedule = isl::schedule::from_domain (domain);
441450
442451 isl::union_map newReads, newWrites;
443- std::tie (newReads, newWrites) = extractAccesses (domain, op, accesses);
452+ std::tie (newReads, newWrites) =
453+ extractAccesses (iterationDomain, op, accesses);
444454
445455 *reads = reads->unite (newReads);
446456 *writes = writes->unite (newWrites);
0 commit comments