2020
2121#include " tc/core/check.h"
2222#include " tc/core/constants.h"
23+ #include " tc/core/polyhedral/body.h"
2324#include " tc/core/polyhedral/schedule_isl_conversion.h"
2425#include " tc/core/polyhedral/schedule_transforms.h"
2526#include " tc/core/polyhedral/schedule_tree.h"
@@ -333,6 +334,90 @@ std::pair<isl::union_map, isl::union_map> extractAccesses(
333334 return {finder.reads , finder.writes };
334335}
335336
337+ bool isReductionUpdate (const Provide* op) {
338+ if (const Call* call = op->values [0 ].as <Call>()) {
339+ return call->is_intrinsic (tc2halide::kReductionUpdate );
340+ } else {
341+ return false ;
342+ }
343+ }
344+
345+ /* Construct a multi-dimensional affine function mapping
346+ * the given iteration domain
347+ * to the outer loop iterators that do not appear in "skip".
348+ * "id" is used as the identifier of the target space.
349+ * For each of these outer loop iterators, an affine function
350+ * is first constructed in terms of the parameter space
351+ * active at the point where the iteration domain was created and
352+ * then converted into an expression on that iteration domain
353+ * by reinterpreting the parameters as input dimensions.
354+ */
355+ static isl::multi_aff mapToOther (
356+ const IterationDomain& iterationDomain,
357+ std::unordered_set<std::string> skip,
358+ isl::id id) {
359+ auto ctx = iterationDomain.tuple .get_ctx ();
360+ auto list = isl::aff_list (ctx, 0 );
361+ for (auto id : iterationDomain.tuple .get_id_list ()) {
362+ if (skip.count (id.get_name ()) == 1 ) {
363+ continue ;
364+ }
365+ auto aff = isl::aff::param_on_domain_space (iterationDomain.paramSpace , id);
366+ aff = aff.unbind_params_insert_domain (iterationDomain.tuple );
367+ list = list.add (aff);
368+ }
369+ auto domainSpace = iterationDomain.tuple .get_space ();
370+ auto space = domainSpace.params ().named_set_from_params_id (id, list.size ());
371+ space = domainSpace.product (space).unwrap ();
372+ return isl::multi_aff (space, list);
373+ }
374+
375+ /*
376+ * If "op" performs a reduction, then return a mapping from
377+ * the statement instances to the individual reductions.
378+ * Otherwise, return an empty isl::union_map.
379+ *
380+ * "op" is considered to be a reduction if it has been marked
381+ * as performing a reduction and if more than one statement instance
382+ * is involved in the individual reductions.
383+ *
384+ * The space of the reduction has a name of the form R_<op->name>_<index>.
385+ * Each reduction is indexed by the outer loop variables
386+ * that are not marked as reduction variables.
387+ * Since the loop variables that iterate over output tensor elements
388+ * are never marked as reduction variables, this means in particular
389+ * that all statement instances that belong to the same reduction
390+ * write to the same tensor element.
391+ */
392+ isl::union_map extractReduction (
393+ const IterationDomain& iterationDomain,
394+ const Provide* op,
395+ size_t index) {
396+ class FindReductionVars : public IRVisitor {
397+ void visit (const Variable* op) {
398+ if (op->reduction_domain .defined ()) {
399+ reductionVars.insert (op->name );
400+ }
401+ }
402+
403+ public:
404+ // The variables that are known to be reduction variables.
405+ std::unordered_set<std::string> reductionVars;
406+ } finder;
407+
408+ if (!isReductionUpdate (op)) {
409+ return isl::union_map::empty (iterationDomain.tuple .get_space ().params ());
410+ }
411+ op->accept (&finder);
412+ if (finder.reductionVars .size () == 0 ) {
413+ return isl::union_map::empty (iterationDomain.tuple .get_space ().params ());
414+ }
415+ auto ctx = iterationDomain.tuple .get_ctx ();
416+ isl::id id (ctx, kReductionLabel + op->name + " _" + std::to_string (index));
417+ auto reduction = mapToOther (iterationDomain, finder.reductionVars , id);
418+ return isl::union_map (isl::map (reduction));
419+ }
420+
336421/*
337422 * Take a parametric expression "f" and convert it into an expression
338423 * on the iteration domains in "domain" by reinterpreting the parameters
@@ -360,7 +445,7 @@ onDomains(isl::aff f, isl::union_set domain, const IterationDomainMap& map) {
360445 * from outermost to innermost.
361446 * Return the schedule corresponding to the subtree at "s".
362447 *
363- * "reads" and "writes" collect the accesses found along the way.
448+ * "body" collects the accesses and reductions found along the way.
364449 * "accesses" collects the mapping from Call (for the reads) and Provide nodes
365450 * (for the writes) to the corresponding tag in the access relations.
366451 * "statements" collects the mapping from instance set tuple identifiers
@@ -372,8 +457,7 @@ isl::schedule makeScheduleTreeHelper(
372457 const Stmt& s,
373458 isl::set set,
374459 isl::id_list outer,
375- isl::union_map* reads,
376- isl::union_map* writes,
460+ Body* body,
377461 AccessMap* accesses,
378462 StatementMap* statements,
379463 IterationDomainMap* domains) {
@@ -406,19 +490,19 @@ isl::schedule makeScheduleTreeHelper(
406490
407491 // Recursively descend.
408492 auto outerNext = outer.add (isl::id (set.get_ctx (), op->name ));
409- auto body = makeScheduleTreeHelper (
410- op->body , set, outerNext, reads, writes , accesses, statements, domains);
493+ auto bodySchedule = makeScheduleTreeHelper (
494+ op->body , set, outerNext, body , accesses, statements, domains);
411495
412496 // Create an affine function that defines an ordering for all
413497 // the statements in the body of this loop over the values of
414498 // this loop. Start from a parametric expression equal
415499 // to the current loop iterator and then convert it to
416500 // a function on the statements in the domain of the body schedule.
417501 auto aff = isl::aff::param_on_domain_space (space, id);
418- auto domain = body .get_domain ();
502+ auto domain = bodySchedule .get_domain ();
419503 auto mupa = isl::multi_union_pw_aff (onDomains (aff, domain, *domains));
420504
421- schedule = body .insert_partial_schedule (mupa);
505+ schedule = bodySchedule .insert_partial_schedule (mupa);
422506 } else if (auto op = s.as <Halide::Internal::Block>()) {
423507 std::vector<Stmt> stmts;
424508 stmts.push_back (op->first );
@@ -429,7 +513,7 @@ isl::schedule makeScheduleTreeHelper(
429513 std::vector<isl::schedule> schedules;
430514 for (Stmt stmt : stmts) {
431515 schedules.push_back (makeScheduleTreeHelper (
432- stmt, set, outer, reads, writes , accesses, statements, domains));
516+ stmt, set, outer, body , accesses, statements, domains));
433517 }
434518 schedule = schedules[0 ].sequence (schedules[1 ]);
435519
@@ -452,9 +536,13 @@ isl::schedule makeScheduleTreeHelper(
452536 isl::union_map newReads, newWrites;
453537 std::tie (newReads, newWrites) =
454538 extractAccesses (iterationDomain, op, accesses);
539+ // A tensor may be involved in multiple reductions.
540+ // Use the statement index to differentiate between them.
541+ auto newReduction = extractReduction (iterationDomain, op, stmtIndex);
455542
456- *reads = reads->unite (newReads);
457- *writes = writes->unite (newWrites);
543+ body->reads = body->reads .unite (newReads);
544+ body->writes = body->writes .unite (newWrites);
545+ body->reductions = body->reductions .unite (newReduction);
458546
459547 } else {
460548 LOG (FATAL) << " Unhandled Halide stmt: " << s;
@@ -465,87 +553,24 @@ isl::schedule makeScheduleTreeHelper(
465553ScheduleTreeAndAccesses makeScheduleTree (isl::space paramSpace, const Stmt& s) {
466554 ScheduleTreeAndAccesses result;
467555
468- result. writes = result. reads = isl::union_map::empty (paramSpace);
556+ Body body (paramSpace);
469557
470558 // Walk the IR building a schedule tree
471559 isl::id_list outer (paramSpace.get_ctx (), 0 );
472560 auto schedule = makeScheduleTreeHelper (
473561 s,
474562 isl::set::universe (paramSpace),
475563 outer,
476- &result.reads ,
477- &result.writes ,
564+ &body,
478565 &result.accesses ,
479566 &result.statements ,
480567 &result.domains );
481568
569+ result.body = body;
482570 result.tree = fromIslSchedule (schedule);
483571
484572 return result;
485573}
486574
487- std::vector<Reduction> findReductions (const Stmt& s) {
488- class FindReductions : public IRVisitor {
489- using IRVisitor::visit;
490-
491- bool isReductionUpdate (const Provide* op) {
492- if (const Call* call = op->values [0 ].as <Call>()) {
493- return call->is_intrinsic (tc2halide::kReductionUpdate );
494- } else {
495- return false ;
496- }
497- }
498-
499- // Keep track of any reduction variable name for use in visit(Provide*)
500- void visit (const Variable* op) {
501- if (op->reduction_domain .defined ()) {
502- reductionVars.insert (op->name );
503- }
504- }
505-
506- // Keep track of the names of the outer For nodes.
507- void visit (const For* op) {
508- vars.push_back (op->name );
509- IRVisitor::visit (op);
510- vars.pop_back ();
511- }
512-
513- // Check if the node is an update node with at least one reduction
514- // dimension, keeping track of the information about the reduction.
515- // In particular, collect the positions of the reduction
516- // dimensions in the update statement domain.
517- // Visit the children first to ensure that all relevant
518- // reduction variables have been found first.
519- void visit (const Provide* op) {
520- IRVisitor::visit (op);
521- if (isReductionUpdate (op)) {
522- std::vector<size_t > dims;
523- auto n = vars.size ();
524- for (size_t i = 0 ; i < n; ++i) {
525- if (reductionVars.count (vars[i]) != 0 ) {
526- dims.emplace_back (i);
527- }
528- }
529- if (dims.size () > 0 ) {
530- Reduction p;
531- p.update = op;
532- p.dims = dims;
533- reductions.emplace_back (p);
534- }
535- }
536- }
537-
538- public:
539- // The variables that are known to be reduction variables.
540- std::unordered_set<std::string> reductionVars;
541- // The names of the outer For nodes, outermost to innermost.
542- std::vector<std::string> vars;
543- std::vector<Reduction> reductions;
544- } finder;
545- s.accept (&finder);
546-
547- return finder.reductions ;
548- }
549-
550575} // namespace halide2isl
551576} // namespace tc
0 commit comments