@@ -1461,6 +1461,20 @@ map<IndexVar,Dimension> IndexStmt::getIndexVarDomains() const {
14611461 return indexVarDomains;
14621462}
14631463
1464+
1465+
1466+ IndexStmt IndexStmt::concretizeScheduled (ProvenanceGraph provGraph) const {
1467+ IndexStmt stmt = *this ;
1468+ string r;
1469+ if (isEinsumNotation (stmt)) {
1470+ stmt = makeReductionNotationScheduled (stmt, provGraph);
1471+ }
1472+ if (isReductionNotationScheduled (stmt, provGraph, &r)) {
1473+ stmt = makeConcreteNotationScheduled (stmt, provGraph);
1474+ }
1475+ return stmt;
1476+ }
1477+
14641478IndexStmt IndexStmt::concretize () const {
14651479 IndexStmt stmt = *this ;
14661480 if (isEinsumNotation (stmt)) {
@@ -2322,8 +2336,10 @@ bool isReductionNotation(IndexStmt stmt, std::string* reason) {
23222336 bool isReduction = true ;
23232337
23242338 util::ScopedMap<IndexVar,int > boundVars; // (int) value not used
2339+ vector<IndexVar> boundVarsList;
23252340 for (auto & var : to<Assignment>(stmt).getFreeVars ()) {
23262341 boundVars.insert ({var,0 });
2342+ boundVarsList.push_back (var);
23272343 }
23282344
23292345 match (stmt,
@@ -2347,6 +2363,67 @@ bool isReductionNotation(IndexStmt stmt, std::string* reason) {
23472363 return isReduction;
23482364}
23492365
2366+ bool isReductionNotationScheduled (IndexStmt stmt, ProvenanceGraph provGraph, std::string* reason) {
2367+ INIT_REASON (reason);
2368+
2369+ if (!isa<Assignment>(stmt)) {
2370+ *reason = " reduction notation statements must be assignments" ;
2371+ return false ;
2372+ }
2373+
2374+ if (!isValid (to<Assignment>(stmt), reason)) {
2375+ return false ;
2376+ }
2377+
2378+ // Reduction notation until proved otherwise
2379+ bool isReduction = true ;
2380+
2381+ util::ScopedMap<IndexVar,int > boundVars; // (int) value not used
2382+ vector<IndexVar> boundVarsList;
2383+ for (auto & var : to<Assignment>(stmt).getFreeVars ()) {
2384+ boundVars.insert ({var,0 });
2385+ boundVarsList.push_back (var);
2386+ }
2387+
2388+ match (stmt,
2389+ std::function<void (const ReductionNode*,Matcher*)>([&](
2390+ const ReductionNode* op, Matcher* ctx) {
2391+ boundVars.scope ();
2392+ boundVars.insert ({op->var ,0 });
2393+ ctx->match (op->a );
2394+ boundVars.unscope ();
2395+ }),
2396+ std::function<void (const AccessNode*)>([&](const AccessNode* op) {
2397+ for (auto & var : op->indexVars ) {
2398+ if (!boundVars.contains (var)) {
2399+ // This detects to see if one of the boundVars is an ancestor of var
2400+ // or if boundVars is a descendant of var given the Provenance Graph.
2401+ // If either of these are true, then the statement is still in reduction notation.
2402+ if (provGraph.isFullyDerived (var)) {
2403+ auto ancestors = provGraph.getUnderivedAncestors (var);
2404+ for (auto & ancestor: ancestors) {
2405+ if (boundVars.contains (ancestor)) {
2406+ return true ;
2407+ }
2408+ }
2409+ } else {
2410+ auto descendants = provGraph.getFullyDerivedDescendants (var);
2411+ for (auto & descendant : descendants) {
2412+ if (boundVars.contains (descendant)) {
2413+ return true ;
2414+ }
2415+ }
2416+ }
2417+ *reason = " all reduction variables in reduction notation must be "
2418+ " bound by a reduction expression" ;
2419+ isReduction = false ;
2420+ }
2421+ }
2422+ })
2423+ );
2424+ return isReduction;
2425+ }
2426+
23502427bool isConcreteNotation (IndexStmt stmt, std::string* reason) {
23512428 taco_iassert (stmt.defined ()) << " the index statement is undefined" ;
23522429 INIT_REASON (reason);
@@ -2605,6 +2682,207 @@ IndexStmt makeConcreteNotation(IndexStmt stmt) {
26052682 return stmt;
26062683}
26072684
2685+ Assignment makeReductionNotationScheduled (Assignment assignment, ProvenanceGraph provGraph) {
2686+ IndexExpr expr = assignment.getRhs ();
2687+ std::vector<IndexVar> free = assignment.getLhs ().getIndexVars ();
2688+ if (!isEinsumNotation (assignment)) {
2689+ return assignment;
2690+ }
2691+
2692+ struct MakeReductionNotation : IndexNotationRewriter {
2693+ MakeReductionNotation (const std::vector<IndexVar>& free, ProvenanceGraph provGraph)
2694+ : free(free.begin(), free.end()), provGraph(provGraph){}
2695+
2696+ ProvenanceGraph provGraph;
2697+ std::set<IndexVar> free;
2698+ bool onlyOneTerm;
2699+
2700+ IndexExpr addReductions (IndexExpr expr) {
2701+ auto vars = getIndexVars (expr);
2702+ for (auto & var : util::reverse (vars)) {
2703+
2704+ if (!util::contains (free, var)) {
2705+ bool shouldReduce = true ;
2706+ // / Do not add a reduction node if mismatch is between a fully derived indexVar and its ancestor
2707+ if (provGraph.isFullyDerived (var)) {
2708+ for (auto & f: free) {
2709+ if (provGraph.isDerivedFrom (var, f)) {
2710+ shouldReduce = false ;
2711+ }
2712+ }
2713+ } else {
2714+ for (auto & f: free) {
2715+ if (provGraph.isDerivedFrom (f, var)) {
2716+ shouldReduce = false ;
2717+ }
2718+ }
2719+ }
2720+ if (shouldReduce)
2721+ expr = sum (var,expr);
2722+ }
2723+ }
2724+ return expr;
2725+ }
2726+
2727+ IndexExpr einsum (const IndexExpr& expr) {
2728+ onlyOneTerm = true ;
2729+ IndexExpr einsumexpr = rewrite (expr);
2730+
2731+ if (onlyOneTerm) {
2732+ einsumexpr = addReductions (einsumexpr);
2733+ }
2734+
2735+ return einsumexpr;
2736+ }
2737+
2738+ using IndexNotationRewriter::visit;
2739+
2740+ void visit (const AddNode* op) {
2741+ // Sum every reduction variables over each term
2742+ onlyOneTerm = false ;
2743+
2744+ IndexExpr a = addReductions (op->a );
2745+ IndexExpr b = addReductions (op->b );
2746+ if (a == op->a && b == op->b ) {
2747+ expr = op;
2748+ }
2749+ else {
2750+ expr = new AddNode (a, b);
2751+ }
2752+ }
2753+
2754+ void visit (const SubNode* op) {
2755+ // Sum every reduction variables over each term
2756+ onlyOneTerm = false ;
2757+
2758+ IndexExpr a = addReductions (op->a );
2759+ IndexExpr b = addReductions (op->b );
2760+ if (a == op->a && b == op->b ) {
2761+ expr = op;
2762+ }
2763+ else {
2764+ expr = new SubNode (a, b);
2765+ }
2766+ }
2767+ };
2768+ return Assignment (assignment.getLhs (),
2769+ MakeReductionNotation (free, provGraph).einsum (expr),
2770+ assignment.getOperator ());
2771+ }
2772+
2773+ IndexStmt makeReductionNotationScheduled (IndexStmt stmt, ProvenanceGraph provGraph) {
2774+ taco_iassert (isEinsumNotation (stmt));
2775+ return makeReductionNotationScheduled (to<Assignment>(stmt), provGraph);
2776+ }
2777+
2778+ IndexStmt makeConcreteNotationScheduled (IndexStmt stmt, ProvenanceGraph provGraph) {
2779+ std::string reason;
2780+ taco_iassert (isReductionNotationScheduled (stmt, provGraph, &reason))
2781+ << " Not reduction notation: " << stmt << std::endl << reason;
2782+ taco_iassert (isa<Assignment>(stmt));
2783+
2784+ // Free variables and reductions covering the whole rhs become top level loops
2785+ vector<IndexVar> freeVars = to<Assignment>(stmt).getFreeVars ();
2786+
2787+ struct RemoveTopLevelReductions : IndexNotationRewriter {
2788+ using IndexNotationRewriter::visit;
2789+
2790+ void visit (const AssignmentNode* node) {
2791+ // Easiest to just walk down the reduction node until we find something
2792+ // that's not a reduction
2793+ vector<IndexVar> topLevelReductions;
2794+ IndexExpr rhs = node->rhs ;
2795+ while (isa<Reduction>(rhs)) {
2796+ Reduction reduction = to<Reduction>(rhs);
2797+ topLevelReductions.push_back (reduction.getVar ());
2798+ rhs = reduction.getExpr ();
2799+ }
2800+
2801+ if (rhs != node->rhs ) {
2802+ stmt = Assignment (node->lhs , rhs, Add ());
2803+ for (auto & i : util::reverse (topLevelReductions)) {
2804+ stmt = forall (i, stmt);
2805+ }
2806+ }
2807+ else {
2808+ stmt = node;
2809+ }
2810+ }
2811+ };
2812+ stmt = RemoveTopLevelReductions ().rewrite (stmt);
2813+
2814+ // This gets the list of indexVars on the rhs of an assignment
2815+ // TODO: check to make sure that we want to get ALL rhs indexVars (not just the upper level)
2816+ vector<IndexVar> rhsVars;
2817+ match (stmt,
2818+ function<void (const AccessNode*, Matcher*)>([&](const AccessNode* op, Matcher* ctx) {
2819+ rhsVars.insert (rhsVars.end (), op->indexVars .begin (), op->indexVars .end ());
2820+ }),
2821+ function<void (const AssignmentNode*, Matcher*)>([&](const AssignmentNode* op, Matcher* ctx) {
2822+ ctx->match (op->rhs );
2823+ })
2824+ );
2825+
2826+ // Emit the freeVars as foralls if the freeVars are fully derived
2827+ // else emit the fully derived descendant of the freeVar found in rhsVars
2828+ for (auto & i : util::reverse (freeVars)) {
2829+ if (provGraph.isFullyDerived (i))
2830+ stmt = forall (i, stmt);
2831+ else {
2832+ auto derivedVars = provGraph.getFullyDerivedDescendants (i);
2833+ IndexVar derivedI = *rhsVars.begin ();
2834+ for (auto & derivedVar : derivedVars) {
2835+ if (std::find (rhsVars.begin (), rhsVars.end (), derivedVar) != rhsVars.end ()) {
2836+ derivedI = derivedVar;
2837+ }
2838+ }
2839+ stmt = forall (derivedI, stmt);
2840+ }
2841+ }
2842+
2843+ // Replace other reductions with where and forall statements
2844+ struct ReplaceReductionsWithWheres : IndexNotationRewriter {
2845+ using IndexNotationRewriter::visit;
2846+
2847+ Reduction reduction;
2848+ TensorVar t;
2849+
2850+ void visit (const AssignmentNode* node) {
2851+ reduction = Reduction ();
2852+ t = TensorVar ();
2853+
2854+ IndexExpr rhs = rewrite (node->rhs );
2855+
2856+ // nothing was rewritten
2857+ if (rhs == node->rhs ) {
2858+ stmt = node;
2859+ return ;
2860+ }
2861+
2862+ taco_iassert (t.defined () && reduction.defined ());
2863+ IndexStmt consumer = Assignment (node->lhs , rhs, node->op );
2864+ IndexStmt producer = forall (reduction.getVar (),
2865+ Assignment (t, reduction.getExpr (),
2866+ reduction.getOp ()));
2867+ stmt = where (rewrite (consumer), rewrite (producer));
2868+ }
2869+
2870+ void visit (const ReductionNode* node) {
2871+ // only rewrite one reduction at a time
2872+ if (reduction.defined ()) {
2873+ expr = node;
2874+ return ;
2875+ }
2876+
2877+ reduction = node;
2878+ t = TensorVar (" t" + util::toString (node->var ),
2879+ node->getDataType ());
2880+ expr = t;
2881+ }
2882+ };
2883+ stmt = ReplaceReductionsWithWheres ().rewrite (stmt);
2884+ return stmt;
2885+ }
26082886
26092887vector<TensorVar> getResults (IndexStmt stmt) {
26102888 vector<TensorVar> result;
0 commit comments