@@ -1463,14 +1463,18 @@ map<IndexVar,Dimension> IndexStmt::getIndexVarDomains() const {
14631463
14641464
14651465
1466- IndexStmt IndexStmt::concretizeScheduled (ProvenanceGraph provGraph) const {
1466+ IndexStmt IndexStmt::concretizeScheduled (ProvenanceGraph provGraph, vector<IndexVar> forallIndexVarList ) const {
14671467 IndexStmt stmt = *this ;
14681468 string r;
1469- if (isEinsumNotation (stmt)) {
1469+ cout << " Pre concretized stmt: " << stmt << endl;
1470+ if (isEinsumNotation (stmt, &r)) {
14701471 stmt = makeReductionNotationScheduled (stmt, provGraph);
1472+ cout << " Post Reduction Stmt: " << stmt << endl;
14711473 }
1474+ cout << r << endl;
14721475 if (isReductionNotationScheduled (stmt, provGraph, &r)) {
1473- stmt = makeConcreteNotationScheduled (stmt, provGraph);
1476+ stmt = makeConcreteNotationScheduled (stmt, provGraph, forallIndexVarList);
1477+ cout << " Post Concretize Stmt: " << stmt << endl;
14741478 }
14751479 return stmt;
14761480}
@@ -2775,17 +2779,22 @@ IndexStmt makeReductionNotationScheduled(IndexStmt stmt, ProvenanceGraph provGra
27752779 return makeReductionNotationScheduled (to<Assignment>(stmt), provGraph);
27762780}
27772781
2778- IndexStmt makeConcreteNotationScheduled (IndexStmt stmt, ProvenanceGraph provGraph) {
2782+ IndexStmt makeConcreteNotationScheduled (IndexStmt stmt, ProvenanceGraph provGraph, vector<IndexVar> forallIndexVars ) {
27792783 std::string reason;
27802784 taco_iassert (isReductionNotationScheduled (stmt, provGraph, &reason))
27812785 << " Not reduction notation: " << stmt << std::endl << reason;
27822786 taco_iassert (isa<Assignment>(stmt));
27832787
27842788 // Free variables and reductions covering the whole rhs become top level loops
27852789 vector<IndexVar> freeVars = to<Assignment>(stmt).getFreeVars ();
2790+ vector<IndexVar> reductionAndFreeVars;
27862791
27872792 struct RemoveTopLevelReductions : IndexNotationRewriter {
27882793 using IndexNotationRewriter::visit;
2794+ vector<IndexVar> forallIndexVars;
2795+ vector<IndexVar> reductionAndFreeVars;
2796+
2797+ RemoveTopLevelReductions (vector<IndexVar> forallIndexVars) : forallIndexVars(forallIndexVars) {}
27892798
27902799 void visit (const AssignmentNode* node) {
27912800 // Easiest to just walk down the reduction node until we find something
@@ -2800,43 +2809,86 @@ IndexStmt makeConcreteNotationScheduled(IndexStmt stmt, ProvenanceGraph provGrap
28002809
28012810 if (rhs != node->rhs ) {
28022811 stmt = Assignment (node->lhs , rhs, Add ());
2803- for (auto & i : util::reverse (topLevelReductions)) {
2804- stmt = forall (i, stmt);
2812+ if (forallIndexVars.empty ()) {
2813+ for (auto &i : util::reverse (topLevelReductions)) {
2814+ stmt = forall (i, stmt);
2815+ }
2816+ } else {
2817+ reductionAndFreeVars.insert (reductionAndFreeVars.end (), topLevelReductions.begin (), topLevelReductions.end ());
28052818 }
28062819 }
28072820 else {
28082821 stmt = node;
28092822 }
28102823 }
28112824 };
2812- stmt = RemoveTopLevelReductions ().rewrite (stmt);
2813-
2825+ auto rewriter = RemoveTopLevelReductions (forallIndexVars);
2826+ stmt = rewriter.rewrite (stmt);
2827+ reductionAndFreeVars = rewriter.reductionAndFreeVars ;
28142828 // This gets the list of indexVars on the rhs of an assignment
28152829 // TODO: check to make sure that we want to get ALL rhs indexVars (not just the upper level)
28162830 vector<IndexVar> rhsVars;
28172831 match (stmt,
28182832 function<void (const AccessNode*, Matcher*)>([&](const AccessNode* op, Matcher* ctx) {
2819- rhsVars.insert (rhsVars.end (), op->indexVars .begin (), op->indexVars .end ());
2833+ for (auto &i : op->indexVars ) {
2834+ if (std::find (rhsVars.begin (), rhsVars.end (), i) == rhsVars.end ()) {
2835+ rhsVars.push_back (i);
2836+ }
2837+ }
28202838 }),
28212839 function<void (const AssignmentNode*, Matcher*)>([&](const AssignmentNode* op, Matcher* ctx) {
28222840 ctx->match (op->rhs );
28232841 })
28242842 );
28252843
2844+ cout << " freeVars: " ;
2845+ for (auto &i : freeVars) {
2846+ cout << i << " , " ;
2847+ }
2848+ cout << endl;
2849+
2850+ cout << " rhsVars: " ;
2851+ for (auto &i : rhsVars) {
2852+ cout << i << " , " ;
2853+ }
2854+ cout << endl;
2855+
2856+ cout << " forallIndexVars: " ;
2857+ for (auto &i : forallIndexVars) {
2858+ cout << i << " , " ;
2859+ }
2860+ cout << endl;
28262861 // Emit the freeVars as foralls if the freeVars are fully derived
28272862 // else emit the fully derived descendant of the freeVar found in rhsVars
2828- for (auto & i : util::reverse (freeVars)) {
2829- if (provGraph.isFullyDerived (i))
2830- stmt = forall (i, stmt);
2831- else {
2832- auto derivedVars = provGraph.getFullyDerivedDescendants (i);
2833- IndexVar derivedI = *rhsVars.begin ();
2834- for (auto & derivedVar : derivedVars) {
2835- if (std::find (rhsVars.begin (), rhsVars.end (), derivedVar) != rhsVars.end ()) {
2836- derivedI = derivedVar;
2863+ if (forallIndexVars.empty ()) {
2864+ for (auto &i : util::reverse (freeVars)) {
2865+ if (provGraph.isFullyDerived (i))
2866+ stmt = forall (i, stmt);
2867+ else {
2868+ auto derivedVars = provGraph.getFullyDerivedDescendants (i);
2869+ IndexVar derivedI = *rhsVars.begin ();
2870+ for (auto &derivedVar : derivedVars) {
2871+ if (std::find (rhsVars.begin (), rhsVars.end (), derivedVar) != rhsVars.end ()) {
2872+ derivedI = derivedVar;
2873+ }
2874+ }
2875+ stmt = forall (derivedI, stmt);
2876+ }
2877+ }
2878+ } else {
2879+ reductionAndFreeVars.insert (reductionAndFreeVars.end (), freeVars.begin (), freeVars.end ());
2880+ for (auto &i : util::reverse (forallIndexVars)) {
2881+ if (std::find (reductionAndFreeVars.begin (), reductionAndFreeVars.end (), i) != reductionAndFreeVars.end ())
2882+ stmt = forall (i, stmt);
2883+ else {
2884+ auto ancestorVars = provGraph.getUnderivedAncestors (i);
2885+ IndexVar ancestorI = *reductionAndFreeVars.begin ();
2886+ for (auto &ancestorVar : ancestorVars) {
2887+ if (std::find (reductionAndFreeVars.begin (), reductionAndFreeVars.end (), ancestorVar) != reductionAndFreeVars.end ()) {
2888+ stmt = forall (i, stmt);
2889+ }
28372890 }
28382891 }
2839- stmt = forall (derivedI, stmt);
28402892 }
28412893 }
28422894
0 commit comments