@@ -333,6 +333,22 @@ std::pair<isl::union_map, isl::union_map> extractAccesses(
333333 return {finder.reads , finder.writes };
334334}
335335
336+ /*
337+ * Take a parametric expression "f" and convert it into an expression
338+ * on the iteration domains in "domain" by reinterpreting the parameters
339+ * as set dimensions according to the corresponding tuples in "map".
340+ */
341+ isl::union_pw_aff
342+ onDomains (isl::aff f, isl::union_set domain, const IterationDomainMap& map) {
343+ auto upa = isl::union_pw_aff::empty (domain.get_space ());
344+ for (auto set : domain.get_set_list ()) {
345+ auto tuple = map.at (set.get_tuple_id ()).tuple ;
346+ auto onSet = isl::union_pw_aff (f.unbind_params_insert_domain (tuple));
347+ upa = upa.union_add (onSet);
348+ }
349+ return upa;
350+ }
351+
336352} // namespace
337353
338354/*
@@ -395,20 +411,12 @@ isl::schedule makeScheduleTreeHelper(
395411
396412 // Create an affine function that defines an ordering for all
397413 // the statements in the body of this loop over the values of
398- // this loop. For each statement in the children we want the
399- // function that maps everything in its space to this
400- // dimension. The spaces may be different, but they'll all have
401- // this loop var at the same index.
402- isl::multi_union_pw_aff mupa;
403- body.get_domain ().foreach_set ([&](isl::set s) {
404- isl::aff newLoopVar (
405- isl::local_space (s.get_space ()), isl::dim_type::set, outer.n ());
406- if (mupa) {
407- mupa = mupa.union_add (isl::union_pw_aff (isl::pw_aff (newLoopVar)));
408- } else {
409- mupa = isl::union_pw_aff (isl::pw_aff (newLoopVar));
410- }
411- });
414+ // this loop. Start from a parametric expression equal
415+ // to the current loop iterator and then convert it to
416+ // a function on the statements in the domain of the body schedule.
417+ auto aff = isl::aff::param_on_domain_space (space, id);
418+ auto domain = body.get_domain ();
419+ auto mupa = isl::multi_union_pw_aff (onDomains (aff, domain, *domains));
412420
413421 schedule = body.insert_partial_schedule (mupa);
414422 } else if (auto op = s.as <Halide::Internal::Block>()) {
0 commit comments