2626#include " taco/tensor.h"
2727
2828#include " taco/util/name_generator.h"
29+ #include " taco/util/scopedset.h"
2930#include " taco/util/scopedmap.h"
3031#include " taco/util/strings.h"
3132#include " taco/util/collections.h"
@@ -1525,11 +1526,9 @@ IndexStmt IndexStmt::divide(IndexVar i, IndexVar i1, IndexVar i2, size_t splitFa
15251526IndexStmt IndexStmt::precompute (IndexExpr expr, std::vector<IndexVar> i_vars,
15261527 std::vector<IndexVar> iw_vars, TensorVar workspace) const {
15271528
1528- // TODO: need to assert they are same length
15291529 IndexStmt transformed = *this ;
15301530 string reason;
15311531
1532- // FIXME: need to re-enable this later
15331532 taco_uassert (i_vars.size () == iw_vars.size ()) << " The precompute transformation requires"
15341533 << " i_vars and iw_vars to be the same size" ;
15351534 for (int l = 0 ; l < (int ) i_vars.size (); l++) {
@@ -2343,18 +2342,18 @@ bool isReductionNotation(IndexStmt stmt, std::string* reason) {
23432342 // Reduction notation until proved otherwise
23442343 bool isReduction = true ;
23452344
2346- util::ScopedMap <IndexVar, int > boundVars; // (int) value not used
2345+ util::ScopedSet <IndexVar> boundVars;
23472346 vector<IndexVar> boundVarsList;
23482347 for (auto & var : to<Assignment>(stmt).getFreeVars ()) {
2349- boundVars.insert ({var, 0 });
2348+ boundVars.insert ({var});
23502349 boundVarsList.push_back (var);
23512350 }
23522351
23532352 match (stmt,
23542353 std::function<void (const ReductionNode*,Matcher*)>([&](
23552354 const ReductionNode* op, Matcher* ctx) {
23562355 boundVars.scope ();
2357- boundVars.insert ({op->var , 0 });
2356+ boundVars.insert ({op->var });
23582357 ctx->match (op->a );
23592358 boundVars.unscope ();
23602359 }),
@@ -2386,18 +2385,18 @@ bool isReductionNotationScheduled(IndexStmt stmt, ProvenanceGraph provGraph, std
23862385 // Reduction notation until proved otherwise
23872386 bool isReduction = true ;
23882387
2389- util::ScopedMap <IndexVar, int > boundVars; // (int) value not used
2388+ util::ScopedSet <IndexVar> boundVars;
23902389 vector<IndexVar> boundVarsList;
23912390 for (auto & var : to<Assignment>(stmt).getFreeVars ()) {
2392- boundVars.insert ({var, 0 });
2391+ boundVars.insert ({var});
23932392 boundVarsList.push_back (var);
23942393 }
23952394
23962395 match (stmt,
23972396 std::function<void (const ReductionNode*,Matcher*)>([&](
23982397 const ReductionNode* op, Matcher* ctx) {
23992398 boundVars.scope ();
2400- boundVars.insert ({op->var , 0 });
2399+ boundVars.insert ({op->var });
24012400 ctx->match (op->a );
24022401 boundVars.unscope ();
24032402 }),
@@ -2441,7 +2440,7 @@ bool isConcreteNotation(IndexStmt stmt, std::string* reason) {
24412440
24422441 bool inWhereProducer = false ;
24432442 bool inWhereConsumer = false ;
2444- util::ScopedMap <IndexVar, int > boundVars; // (int) value not used
2443+ util::ScopedSet <IndexVar> boundVars;
24452444 std::set<IndexVar> definedVars; // used to check if all variables recoverable TODO: need to actually use scope like above
24462445
24472446 ProvenanceGraph provGraph = ProvenanceGraph (stmt);
@@ -2450,7 +2449,7 @@ bool isConcreteNotation(IndexStmt stmt, std::string* reason) {
24502449 std::function<void (const ForallNode*,Matcher*)>([&](const ForallNode* op,
24512450 Matcher* ctx) {
24522451 boundVars.scope ();
2453- boundVars.insert ({op->indexVar , 0 });
2452+ boundVars.insert ({op->indexVar });
24542453 definedVars.insert (op->indexVar );
24552454 ctx->match (op->stmt );
24562455 boundVars.unscope ();
@@ -2606,6 +2605,47 @@ IndexStmt makeReductionNotation(IndexStmt stmt) {
26062605 return makeReductionNotation (to<Assignment>(stmt));
26072606}
26082607
2608+ // Replace other reductions with where and forall statements
2609+ struct ReplaceReductionsWithWheres : IndexNotationRewriter {
2610+ using IndexNotationRewriter::visit;
2611+
2612+ Reduction reduction;
2613+ TensorVar t;
2614+
2615+ void visit (const AssignmentNode* node) {
2616+ reduction = Reduction ();
2617+ t = TensorVar ();
2618+
2619+ IndexExpr rhs = rewrite (node->rhs );
2620+
2621+ // nothing was rewritten
2622+ if (rhs == node->rhs ) {
2623+ stmt = node;
2624+ return ;
2625+ }
2626+
2627+ taco_iassert (t.defined () && reduction.defined ());
2628+ IndexStmt consumer = Assignment (node->lhs , rhs, node->op );
2629+ IndexStmt producer = forall (reduction.getVar (),
2630+ Assignment (t, reduction.getExpr (),
2631+ reduction.getOp ()));
2632+ stmt = where (rewrite (consumer), rewrite (producer));
2633+ }
2634+
2635+ void visit (const ReductionNode* node) {
2636+ // only rewrite one reduction at a time
2637+ if (reduction.defined ()) {
2638+ expr = node;
2639+ return ;
2640+ }
2641+
2642+ reduction = node;
2643+ t = TensorVar (" t" + util::toString (node->var ),
2644+ node->getDataType ());
2645+ expr = t;
2646+ }
2647+ };
2648+
26092649IndexStmt makeConcreteNotation (IndexStmt stmt) {
26102650 std::string reason;
26112651 taco_iassert (isReductionNotation (stmt, &reason))
@@ -2646,46 +2686,6 @@ IndexStmt makeConcreteNotation(IndexStmt stmt) {
26462686 stmt = forall (i, stmt);
26472687 }
26482688
2649- // Replace other reductions with where and forall statements
2650- struct ReplaceReductionsWithWheres : IndexNotationRewriter {
2651- using IndexNotationRewriter::visit;
2652-
2653- Reduction reduction;
2654- TensorVar t;
2655-
2656- void visit (const AssignmentNode* node) {
2657- reduction = Reduction ();
2658- t = TensorVar ();
2659-
2660- IndexExpr rhs = rewrite (node->rhs );
2661-
2662- // nothing was rewritten
2663- if (rhs == node->rhs ) {
2664- stmt = node;
2665- return ;
2666- }
2667-
2668- taco_iassert (t.defined () && reduction.defined ());
2669- IndexStmt consumer = Assignment (node->lhs , rhs, node->op );
2670- IndexStmt producer = forall (reduction.getVar (),
2671- Assignment (t, reduction.getExpr (),
2672- reduction.getOp ()));
2673- stmt = where (rewrite (consumer), rewrite (producer));
2674- }
2675-
2676- void visit (const ReductionNode* node) {
2677- // only rewrite one reduction at a time
2678- if (reduction.defined ()) {
2679- expr = node;
2680- return ;
2681- }
2682-
2683- reduction = node;
2684- t = TensorVar (" t" + util::toString (node->var ),
2685- node->getDataType ());
2686- expr = t;
2687- }
2688- };
26892689 stmt = ReplaceReductionsWithWheres ().rewrite (stmt);
26902690 return stmt;
26912691}
@@ -2882,46 +2882,6 @@ IndexStmt makeConcreteNotationScheduled(IndexStmt stmt, ProvenanceGraph provGrap
28822882 }
28832883 }
28842884
2885- // Replace other reductions with where and forall statements
2886- struct ReplaceReductionsWithWheres : IndexNotationRewriter {
2887- using IndexNotationRewriter::visit;
2888-
2889- Reduction reduction;
2890- TensorVar t;
2891-
2892- void visit (const AssignmentNode* node) {
2893- reduction = Reduction ();
2894- t = TensorVar ();
2895-
2896- IndexExpr rhs = rewrite (node->rhs );
2897-
2898- // nothing was rewritten
2899- if (rhs == node->rhs ) {
2900- stmt = node;
2901- return ;
2902- }
2903-
2904- taco_iassert (t.defined () && reduction.defined ());
2905- IndexStmt consumer = Assignment (node->lhs , rhs, node->op );
2906- IndexStmt producer = forall (reduction.getVar (),
2907- Assignment (t, reduction.getExpr (),
2908- reduction.getOp ()));
2909- stmt = where (rewrite (consumer), rewrite (producer));
2910- }
2911-
2912- void visit (const ReductionNode* node) {
2913- // only rewrite one reduction at a time
2914- if (reduction.defined ()) {
2915- expr = node;
2916- return ;
2917- }
2918-
2919- reduction = node;
2920- t = TensorVar (" t" + util::toString (node->var ),
2921- node->getDataType ());
2922- expr = t;
2923- }
2924- };
29252885 stmt = ReplaceReductionsWithWheres ().rewrite (stmt);
29262886 return stmt;
29272887}
0 commit comments