@@ -174,16 +174,9 @@ std::vector<isl::aff> makeIslAffBoundsFromExpr(
174174 const Max* maxOp = e.as <Max>();
175175
176176 if (const Variable* op = e.as <Variable>()) {
177- isl::local_space ls = isl::local_space (space);
178- int pos = space.find_dim_by_name (isl::dim_type::param, op->name );
179- if (pos >= 0 ) {
180- return {isl::aff (ls, isl::dim_type::param, pos)};
181- } else {
182- // FIXME: thou shalt not rely upon set dimension names
183- pos = space.find_dim_by_name (isl::dim_type::set, op->name );
184- if (pos >= 0 ) {
185- return {isl::aff (ls, isl::dim_type::set, pos)};
186- }
177+ isl::id id (space.get_ctx (), op->name );
178+ if (space.has_param (id)) {
179+ return {isl::aff::param_on_domain_space (space, id)};
187180 }
188181 LOG (FATAL) << " Variable not found in isl::space: " << space << " : " << op
189182 << " : " << op->name << ' \n ' ;
@@ -248,32 +241,28 @@ isl::set makeParamContext(isl::ctx ctx, const ParameterVector& params) {
248241 return context;
249242}
250243
244+ namespace {
245+
251246isl::map extractAccess (
252- isl::set domain,
247+ const IterationDomain& domain,
253248 const IRNode* op,
254249 const std::string& tensor,
255250 const std::vector<Expr>& args,
256251 AccessMap* accesses) {
257252 // Make an isl::map representing this access. It maps from the iteration space
258253 // to the tensor's storage space, using the coordinates accessed.
254+ // First construct a set describing the accessed element
255+ // in terms of the parameters (including those corresponding
256+ // to the outer loop iterators) and then convert this set
257+ // into a map in terms of the iteration domain.
259258
260- isl::space domainSpace = domain.get_space ();
261- isl::space paramSpace = domainSpace.params ();
259+ isl::space paramSpace = domain.paramSpace ;
262260 isl::id tensorID (paramSpace.get_ctx (), tensor);
263- auto rangeSpace = paramSpace.named_set_from_params_id (tensorID, args.size ());
261+ auto tensorSpace = paramSpace.named_set_from_params_id (tensorID, args.size ());
264262
265- // Add a tag to the domain space so that we can maintain a mapping
266- // between each access in the IR and the reads/writes maps.
267- std::string tag = " __tc_ref_" + std::to_string (accesses->size ());
268- isl::id tagID (domain.get_ctx (), tag);
269- accesses->emplace (op, tagID);
270- isl::space tagSpace = paramSpace.named_set_from_params_id (tagID, 0 );
271- domainSpace = domainSpace.product (tagSpace);
272-
273- // Start with a totally unconstrained relation - every point in
274- // the iteration domain could write to every point in the allocation.
275- isl::map map =
276- isl::map::universe (domainSpace.map_from_domain_and_range (rangeSpace));
263+ // Start with a totally unconstrained set - every point in
264+ // the allocation could be accessed.
265+ isl::set access = isl::set::universe (tensorSpace);
277266
278267 for (size_t i = 0 ; i < args.size (); i++) {
279268 // Then add one equality constraint per dimension to encode the
@@ -283,19 +272,34 @@ isl::map extractAccess(
283272
284273 // The coordinate written to in the range ...
285274 auto rangePoint =
286- isl::pw_aff (isl::local_space (rangeSpace ), isl::dim_type::set, i);
287- // ... equals the coordinate accessed as a function of the domain .
288- auto domainPoint = halide2isl::makeIslAffFromExpr (domainSpace , args[i]);
275+ isl::pw_aff (isl::local_space (tensorSpace ), isl::dim_type::set, i);
276+ // ... equals the coordinate accessed as a function of the parameters .
277+ auto domainPoint = halide2isl::makeIslAffFromExpr (tensorSpace , args[i]);
289278 if (!domainPoint.is_null ()) {
290- map = map .intersect (isl::pw_aff (domainPoint).eq_map (rangePoint));
279+ access = access .intersect (isl::pw_aff (domainPoint).eq_set (rangePoint));
291280 }
292281 }
293282
283+ // Now convert the set into a relation with respect to the iteration domain.
284+ auto map = access.unbind_params_insert_domain (domain.tuple );
285+
286+ // Add a tag to the domain space so that we can maintain a mapping
287+ // between each access in the IR and the reads/writes maps.
288+ std::string tag = " __tc_ref_" + std::to_string (accesses->size ());
289+ isl::id tagID (domain.paramSpace .get_ctx (), tag);
290+ accesses->emplace (op, tagID);
291+ isl::space domainSpace = map.get_space ().domain ();
292+ isl::space tagSpace = domainSpace.params ().named_set_from_params_id (tagID, 0 );
293+ domainSpace = domainSpace.product (tagSpace).unwrap ();
294+ map = map.preimage_domain (isl::multi_aff::domain_map (domainSpace));
295+
294296 return map;
295297}
296298
297- std::pair<isl::union_map, isl::union_map>
298- extractAccesses (isl::set domain, const Stmt& s, AccessMap* accesses) {
299+ std::pair<isl::union_map, isl::union_map> extractAccesses (
300+ const IterationDomain& domain,
301+ const Stmt& s,
302+ AccessMap* accesses) {
299303 class FindAccesses : public IRGraphVisitor {
300304 using IRGraphVisitor::visit;
301305
@@ -313,28 +317,46 @@ extractAccesses(isl::set domain, const Stmt& s, AccessMap* accesses) {
313317 writes.unite (extractAccess (domain, op, op->name , op->args , accesses));
314318 }
315319
316- const isl::set & domain;
320+ const IterationDomain & domain;
317321 AccessMap* accesses;
318322
319323 public:
320324 isl::union_map reads, writes;
321325
322- FindAccesses (const isl::set & domain, AccessMap* accesses)
326+ FindAccesses (const IterationDomain & domain, AccessMap* accesses)
323327 : domain(domain),
324328 accesses (accesses),
325- reads(isl::union_map::empty(domain.get_space())),
326- writes(isl::union_map::empty(domain.get_space())) {}
329+ reads(isl::union_map::empty(domain.tuple. get_space())),
330+ writes(isl::union_map::empty(domain.tuple. get_space())) {}
327331 } finder(domain, accesses);
328332 s.accept(&finder);
329333 return {finder.reads , finder.writes };
330334}
331335
336+ /*
337+ * Take a parametric expression "f" and convert it into an expression
338+ * on the iteration domains in "domain" by reinterpreting the parameters
339+ * as set dimensions according to the corresponding tuples in "map".
340+ */
341+ isl::union_pw_aff
342+ onDomains (isl::aff f, isl::union_set domain, const IterationDomainMap& map) {
343+ auto upa = isl::union_pw_aff::empty (domain.get_space ());
344+ for (auto set : domain.get_set_list ()) {
345+ auto tuple = map.at (set.get_tuple_id ()).tuple ;
346+ auto onSet = isl::union_pw_aff (f.unbind_params_insert_domain (tuple));
347+ upa = upa.union_add (onSet);
348+ }
349+ return upa;
350+ }
351+
352+ } // namespace
353+
332354/*
333355 * Helper function for extracting a schedule from a Halide Stmt,
334356 * recursively descending over the Stmt.
335357 * "s" is the current position in the recursive descent.
336358 * "set" describes the bounds on the outer loop iterators.
337- * "outer" contains the names of the outer loop iterators
359+ * "outer" contains the identifiers of the outer loop iterators
338360 * from outermost to innermost.
339361 * Return the schedule corresponding to the subtree at "s".
340362 *
@@ -343,81 +365,58 @@ extractAccesses(isl::set domain, const Stmt& s, AccessMap* accesses) {
343365 * (for the writes) to the corresponding tag in the access relations.
344366 * "statements" collects the mapping from instance set tuple identifiers
345367 * to the corresponding Provide node.
346- * "iterators " collects the mapping from instance set tuple identifiers
347- * to the corresponding outer loop iterator names, from outermost to innermost .
368+ * "domains " collects the mapping from instance set tuple identifiers
369+ * to the corresponding iteration domain information .
348370 */
349371isl::schedule makeScheduleTreeHelper (
350372 const Stmt& s,
351373 isl::set set,
352- std::vector<std::string>& outer,
374+ isl::id_list outer,
353375 isl::union_map* reads,
354376 isl::union_map* writes,
355377 AccessMap* accesses,
356378 StatementMap* statements,
357- IteratorMap* iterators ) {
379+ IterationDomainMap* domains ) {
358380 isl::schedule schedule;
359381 if (auto op = s.as <For>()) {
360- // Add one additional dimension to our set of loop variables
361- int thisLoopIdx = set.dim (isl::dim_type::set);
362- set = set.add_dims (isl::dim_type::set, 1 );
363-
364- // Make an id for this loop var. For set dimensions this is
365- // really just for pretty-printing.
382+ // Make an id for this loop var. It starts out as a parameter.
366383 isl::id id (set.get_ctx (), op->name );
367- set = set.set_dim_id (isl::dim_type::set, thisLoopIdx, id);
384+ auto space = set.get_space (). add_param ( id);
368385
369- // Construct a variable (affine function) that indexes the new dimension of
370- // this space.
371- isl::aff loopVar (
372- isl::local_space (set.get_space ()), isl::dim_type::set, thisLoopIdx);
386+ // Construct a variable (affine function) that references
387+ // the new parameter.
388+ auto loopVar = isl::aff::param_on_domain_space (space, id);
373389
374390 // Then we add our new loop bound constraints.
375- auto lbs = halide2isl::makeIslAffBoundsFromExpr (
376- set. get_space () , op->min , false , true );
391+ auto lbs =
392+ halide2isl::makeIslAffBoundsFromExpr (space , op->min , false , true );
377393 TC_CHECK_GT (lbs.size (), 0u )
378394 << " could not obtain polyhedral lower bounds from " << op->min ;
379395 for (auto lb : lbs) {
380396 set = set.intersect (loopVar.ge_set (lb));
381397 }
382398
383399 Expr max = simplify (op->min + op->extent - 1 );
384- auto ubs =
385- halide2isl::makeIslAffBoundsFromExpr (set.get_space (), max, true , false );
400+ auto ubs = halide2isl::makeIslAffBoundsFromExpr (space, max, true , false );
386401 TC_CHECK_GT (ubs.size (), 0u )
387402 << " could not obtain polyhedral upper bounds from " << max;
388403 for (auto ub : ubs) {
389404 set = set.intersect (ub.ge_set (loopVar));
390405 }
391406
392407 // Recursively descend.
393- auto outerNext = outer;
394- outerNext.push_back (op->name );
408+ auto outerNext = outer.add (isl::id (set.get_ctx (), op->name ));
395409 auto body = makeScheduleTreeHelper (
396- op->body ,
397- set,
398- outerNext,
399- reads,
400- writes,
401- accesses,
402- statements,
403- iterators);
410+ op->body , set, outerNext, reads, writes, accesses, statements, domains);
404411
405412 // Create an affine function that defines an ordering for all
406413 // the statements in the body of this loop over the values of
407- // this loop. For each statement in the children we want the
408- // function that maps everything in its space to this
409- // dimension. The spaces may be different, but they'll all have
410- // this loop var at the same index.
411- isl::multi_union_pw_aff mupa;
412- body.get_domain ().foreach_set ([&](isl::set s) {
413- isl::aff newLoopVar (
414- isl::local_space (s.get_space ()), isl::dim_type::set, thisLoopIdx);
415- if (mupa) {
416- mupa = mupa.union_add (isl::union_pw_aff (isl::pw_aff (newLoopVar)));
417- } else {
418- mupa = isl::union_pw_aff (isl::pw_aff (newLoopVar));
419- }
420- });
414+ // this loop. Start from a parametric expression equal
415+ // to the current loop iterator and then convert it to
416+ // a function on the statements in the domain of the body schedule.
417+ auto aff = isl::aff::param_on_domain_space (space, id);
418+ auto domain = body.get_domain ();
419+ auto mupa = isl::multi_union_pw_aff (onDomains (aff, domain, *domains));
421420
422421 schedule = body.insert_partial_schedule (mupa);
423422 } else if (auto op = s.as <Halide::Internal::Block>()) {
@@ -430,7 +429,7 @@ isl::schedule makeScheduleTreeHelper(
430429 std::vector<isl::schedule> schedules;
431430 for (Stmt stmt : stmts) {
432431 schedules.push_back (makeScheduleTreeHelper (
433- stmt, set, outer, reads, writes, accesses, statements, iterators ));
432+ stmt, set, outer, reads, writes, accesses, statements, domains ));
434433 }
435434 schedule = schedules[0 ].sequence (schedules[1 ]);
436435
@@ -441,13 +440,18 @@ isl::schedule makeScheduleTreeHelper(
441440 size_t stmtIndex = statements->size ();
442441 isl::id id (set.get_ctx (), kStatementLabel + std::to_string (stmtIndex));
443442 statements->emplace (id, op);
444- iterators->emplace (id, outer);
445- isl::set domain = set.set_tuple_id (id);
443+ auto tupleSpace = isl::space (set.get_ctx (), 0 );
444+ tupleSpace = tupleSpace.named_set_from_params_id (id, outer.n ());
445+ IterationDomain iterationDomain;
446+ iterationDomain.paramSpace = set.get_space ();
447+ iterationDomain.tuple = isl::multi_id (tupleSpace, outer);
448+ domains->emplace (id, iterationDomain);
449+ auto domain = set.unbind_params (iterationDomain.tuple );
446450 schedule = isl::schedule::from_domain (domain);
447451
448452 isl::union_map newReads, newWrites;
449453 std::tie (newReads, newWrites) =
450- halide2isl:: extractAccesses (domain , op, accesses);
454+ extractAccesses (iterationDomain , op, accesses);
451455
452456 *reads = reads->unite (newReads);
453457 *writes = writes->unite (newWrites);
@@ -464,7 +468,7 @@ ScheduleTreeAndAccesses makeScheduleTree(isl::space paramSpace, const Stmt& s) {
464468 result.writes = result.reads = isl::union_map::empty (paramSpace);
465469
466470 // Walk the IR building a schedule tree
467- std::vector<std::string> outer;
471+ isl::id_list outer (paramSpace. get_ctx (), 0 ) ;
468472 auto schedule = makeScheduleTreeHelper (
469473 s,
470474 isl::set::universe (paramSpace),
@@ -473,7 +477,7 @@ ScheduleTreeAndAccesses makeScheduleTree(isl::space paramSpace, const Stmt& s) {
473477 &result.writes ,
474478 &result.accesses ,
475479 &result.statements ,
476- &result.iterators );
480+ &result.domains );
477481
478482 result.tree = fromIslSchedule (schedule);
479483
0 commit comments