Skip to content

Commit 54fba50

Browse files
committed
Fix workspaces multi-dimensional transformation algorithm for precompute(...) and fix multi-dimensional lowering
1 parent 372887d commit 54fba50

File tree

9 files changed

+679
-35
lines changed

9 files changed

+679
-35
lines changed

include/taco/index_notation/index_notation.h

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -549,6 +549,10 @@ class IndexStmt : public util::IntrusivePtr<const IndexStmtNode> {
549549
/// Takes any index notation and concretizes unknowns to make it concrete notation
550550
IndexStmt concretize() const;
551551

552+
/// Schedule-aware counterpart of concretize(): takes any index notation and
/// concretizes unknowns to make it concrete notation
553+
/// given a Provenance Graph of indexVars. The graph is forwarded to the
/// *Scheduled notation helpers so that index variables related through it are
/// taken into account when reductions and foralls are inserted.
554+
IndexStmt concretizeScheduled(ProvenanceGraph provGraph) const;
555+
552556
/// The \code{split} transformation splits (strip-mines) an index
553557
/// variable into two nested index variables, where the size of the
554558
/// inner index variable is constant. The size of the outer index
@@ -1126,6 +1130,11 @@ bool isEinsumNotation(IndexStmt, std::string* reason=nullptr);
11261130
/// notation is printed to.
11271131
bool isReductionNotation(IndexStmt, std::string* reason=nullptr);
11281132

1133+
/// Check whether the statement is in the reduction index notation dialect
1134+
/// given a schedule described by the Provenance Graph. An index variable that
/// is not bound directly is still accepted when the Provenance Graph relates
/// it to a bound variable (one of its underived ancestors, or one of its fully
/// derived descendants). On failure an explanation is written through `reason`.
1135+
bool isReductionNotationScheduled(IndexStmt, ProvenanceGraph, std::string* reason=nullptr);
1136+
1137+
11291138
/// Check whether the statement is in the concrete index notation dialect.
11301139
/// This means every index variable has a forall node, there are no reduction
11311140
/// nodes, and that every reduction variable use is nested inside a compound
@@ -1143,6 +1152,18 @@ IndexStmt makeReductionNotation(IndexStmt);
11431152
/// as needed.
11441153
IndexStmt makeConcreteNotation(IndexStmt);
11451154

1155+
1156+
/// Convert einsum notation to reduction notation, by applying Einstein's
1157+
/// summation convention to sum non-free/reduction variables over their term
1158+
/// taking into account a schedule given by the Provenance Graph.
/// A variable is not summed when the Provenance Graph shows it is a fully
/// derived variable derived from a free variable, or when a free variable is
/// derived from it. The Assignment overload returns its argument unchanged if
/// it is not einsum notation; the IndexStmt overload asserts einsum notation.
1159+
Assignment makeReductionNotationScheduled(Assignment, ProvenanceGraph);
1160+
IndexStmt makeReductionNotationScheduled(IndexStmt, ProvenanceGraph);
1161+
1162+
/// Convert reduction notation to concrete notation, by inserting forall nodes,
1163+
/// replacing reduction nodes by compound assignments, and inserting temporaries
1164+
/// as needed while taking into account a schedule given by the Provenance Graph.
/// For a free variable that is not fully derived, the forall is emitted over a
/// fully derived descendant of it that appears on the right-hand side.
1165+
IndexStmt makeConcreteNotationScheduled(IndexStmt, ProvenanceGraph);
1166+
11461167
/// Returns the results of the index statement, in the order they appear.
11471168
std::vector<TensorVar> getResults(IndexStmt stmt);
11481169

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
#ifndef TACO_WORKSPACE_REWRITER_H
2+
#define TACO_WORKSPACE_REWRITER_H
3+
4+
#include <vector>
5+
#include <map>
6+
7+
8+
namespace taco {
9+
// Forward declarations only; this header does not include the definitions.
class TensorVar;
10+
11+
namespace ir {
12+
class Stmt;
13+
class Expr;
14+
}
15+
16+
/// NOTE(review): the earlier description here ("Simplifies a statement (e.g. by
/// applying constant copy propagation)") looks copied from a different pass.
/// Presumably this rewrites the uses of the workspace temporaries listed in
/// `whereTemps` inside `stmt`, using the per-temporary size expressions in
/// `temporarySizeMap`, and returns the rewritten statement -- confirm against
/// the definition.
17+
ir::Stmt rewriteTemporaryGP(const ir::Stmt& stmt, std::vector<TensorVar> whereTemps,
18+
std::map<TensorVar, std::vector<ir::Expr>> temporarySizeMap);
19+
20+
}
21+
#endif

include/taco/lower/lowerer_impl.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -356,6 +356,9 @@ class LowererImpl : public util::Uncopyable {
356356
/// Gets the size of a temporary tensorVar in the where statement
357357
ir::Expr getTemporarySize(Where where);
358358

359+
/// Gets the varDecl of temporary dimensions for dense workspaces only
360+
ir::Stmt getTemporarySizeDecl(Where where);
361+
359362
/// Initializes helper arrays to give dense workspaces sparse acceleration
360363
std::vector<ir::Stmt> codeToInitializeDenseAcceleratorArrays(Where where, bool parallel = false);
361364

@@ -493,6 +496,9 @@ class LowererImpl : public util::Uncopyable {
493496
std::vector<TensorVar> whereTemps;
494497
std::map<TensorVar, const AccessNode *> whereTempsToResult;
495498

499+
std::map<TensorVar, std::vector<ir::Expr>> temporarySizeMap;
500+
std::vector<TensorVar> temporaries;
501+
496502
bool captureNextLocatePos = false;
497503
ir::Stmt capturedLocatePos; // used for whereConsumer when want to replicate same locating
498504

src/index_notation/index_notation.cpp

Lines changed: 278 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1461,6 +1461,20 @@ map<IndexVar,Dimension> IndexStmt::getIndexVarDomains() const {
14611461
return indexVarDomains;
14621462
}
14631463

1464+
1465+
1466+
/// Concretizes this statement while honoring the relations between index
/// variables recorded in `provGraph`.
///
/// Einsum notation is first converted to reduction notation, and reduction
/// notation is then converted to concrete notation. A statement that matches
/// neither dialect is returned as is.
IndexStmt IndexStmt::concretizeScheduled(ProvenanceGraph provGraph) const {
  IndexStmt result = *this;

  // Stage 1: einsum notation -> reduction notation.
  if (isEinsumNotation(result)) {
    result = makeReductionNotationScheduled(result, provGraph);
  }

  // Stage 2: reduction notation -> concrete notation. The explanation string
  // is only a sink for the dialect check; it is not reported to the caller.
  string whyNot;
  if (isReductionNotationScheduled(result, provGraph, &whyNot)) {
    result = makeConcreteNotationScheduled(result, provGraph);
  }

  return result;
}
1477+
14641478
IndexStmt IndexStmt::concretize() const {
14651479
IndexStmt stmt = *this;
14661480
if (isEinsumNotation(stmt)) {
@@ -2322,8 +2336,10 @@ bool isReductionNotation(IndexStmt stmt, std::string* reason) {
23222336
bool isReduction = true;
23232337

23242338
util::ScopedMap<IndexVar,int> boundVars; // (int) value not used
2339+
vector<IndexVar> boundVarsList;
23252340
for (auto& var : to<Assignment>(stmt).getFreeVars()) {
23262341
boundVars.insert({var,0});
2342+
boundVarsList.push_back(var);
23272343
}
23282344

23292345
match(stmt,
@@ -2347,6 +2363,67 @@ bool isReductionNotation(IndexStmt stmt, std::string* reason) {
23472363
return isReduction;
23482364
}
23492365

2366+
/// Checks whether `stmt` is in the reduction notation dialect, given a
/// schedule described by `provGraph`.
///
/// Every index variable used by an access must be bound, either as a free
/// variable of the assignment or by an enclosing reduction node. A variable
/// that is not bound directly is still accepted when the provenance graph ties
/// it to a bound variable:
///  - a fully derived variable is accepted if one of its underived ancestors
///    is bound;
///  - any other variable is accepted if one of its fully derived descendants
///    is bound.
///
/// \param stmt      the statement to check
/// \param provGraph relations between index variables introduced by scheduling
/// \param reason    receives an explanation when the check fails
/// \return true if the statement is in (scheduled) reduction notation
bool isReductionNotationScheduled(IndexStmt stmt, ProvenanceGraph provGraph, std::string* reason) {
  INIT_REASON(reason);

  if (!isa<Assignment>(stmt)) {
    *reason = "reduction notation statements must be assignments";
    return false;
  }

  if (!isValid(to<Assignment>(stmt), reason)) {
    return false;
  }

  // Reduction notation until proved otherwise
  bool isReduction = true;

  util::ScopedMap<IndexVar,int> boundVars;  // (int) value not used
  for (auto& var : to<Assignment>(stmt).getFreeVars()) {
    boundVars.insert({var,0});
  }

  // Returns true if the provenance graph relates `var` to a variable that is
  // currently bound (see the function comment for the two accepted cases).
  auto relatedToBoundVar = [&](const IndexVar& var) -> bool {
    if (provGraph.isFullyDerived(var)) {
      auto ancestors = provGraph.getUnderivedAncestors(var);
      for (auto& ancestor : ancestors) {
        if (boundVars.contains(ancestor)) {
          return true;
        }
      }
    } else {
      auto descendants = provGraph.getFullyDerivedDescendants(var);
      for (auto& descendant : descendants) {
        if (boundVars.contains(descendant)) {
          return true;
        }
      }
    }
    return false;
  };

  match(stmt,
    std::function<void(const ReductionNode*,Matcher*)>([&](
        const ReductionNode* op, Matcher* ctx) {
      // The reduction variable is bound only inside the reduction's operand.
      boundVars.scope();
      boundVars.insert({op->var,0});
      ctx->match(op->a);
      boundVars.unscope();
    }),
    std::function<void(const AccessNode*)>([&](const AccessNode* op) {
      for (auto& var : op->indexVars) {
        // Move on to the next variable instead of leaving the handler, so
        // that every index variable of the access is checked.
        if (boundVars.contains(var) || relatedToBoundVar(var)) {
          continue;
        }
        *reason = "all reduction variables in reduction notation must be "
                  "bound by a reduction expression";
        isReduction = false;
      }
    })
  );
  return isReduction;
}
2426+
23502427
bool isConcreteNotation(IndexStmt stmt, std::string* reason) {
23512428
taco_iassert(stmt.defined()) << "the index statement is undefined";
23522429
INIT_REASON(reason);
@@ -2605,6 +2682,207 @@ IndexStmt makeConcreteNotation(IndexStmt stmt) {
26052682
return stmt;
26062683
}
26072684

2685+
/// Converts an einsum-notation assignment to reduction notation, taking the
/// schedule described by `provGraph` into account.
///
/// Every index variable of a term that is not a free variable of the
/// assignment is wrapped in a sum over that term, except when the provenance
/// graph shows the variable is just a scheduled form of a free variable:
///  - the variable is fully derived and is derived from a free variable, or
///  - the variable is not fully derived and a free variable is derived from it.
///
/// \param assignment the assignment to convert
/// \param provGraph  relations between index variables introduced by scheduling
/// \return the assignment in reduction notation; the input is returned
///         unchanged if it is not in einsum notation
Assignment makeReductionNotationScheduled(Assignment assignment, ProvenanceGraph provGraph) {
  if (!isEinsumNotation(assignment)) {
    return assignment;
  }

  IndexExpr expr = assignment.getRhs();
  std::vector<IndexVar> free = assignment.getLhs().getIndexVars();

  struct MakeReductionNotation : IndexNotationRewriter {
    // Members are initialized in declaration order, so the initializer list
    // is written in that same order.
    MakeReductionNotation(const std::vector<IndexVar>& free, ProvenanceGraph provGraph)
        : provGraph(provGraph), free(free.begin(), free.end()) {}

    ProvenanceGraph provGraph;
    std::set<IndexVar> free;
    // Cleared as soon as an add/sub node is visited; reset by einsum().
    bool onlyOneTerm = true;

    /// True if the provenance graph relates `var` to one of the free
    /// variables, in which case no reduction node must be added for it.
    bool isScheduledFormOfFreeVar(const IndexVar& var) {
      const bool fullyDerived = provGraph.isFullyDerived(var);
      for (auto& f : free) {
        const bool related = fullyDerived ? provGraph.isDerivedFrom(var, f)
                                          : provGraph.isDerivedFrom(f, var);
        if (related) {
          return true;
        }
      }
      return false;
    }

    /// Wraps `expr` in a sum for each of its index variables that is neither
    /// free nor a scheduled form of a free variable.
    IndexExpr addReductions(IndexExpr expr) {
      auto vars = getIndexVars(expr);
      for (auto& var : util::reverse(vars)) {
        if (util::contains(free, var)) {
          continue;
        }
        // Do not add a reduction node if the mismatch is only between a
        // derived indexVar and its ancestor.
        if (isScheduledFormOfFreeVar(var)) {
          continue;
        }
        expr = sum(var,expr);
      }
      return expr;
    }

    /// Rewrites `expr`; if it consists of a single term (no add/sub node was
    /// visited), the reductions are added around the whole expression.
    IndexExpr einsum(const IndexExpr& expr) {
      onlyOneTerm = true;
      IndexExpr einsumexpr = rewrite(expr);

      if (onlyOneTerm) {
        einsumexpr = addReductions(einsumexpr);
      }

      return einsumexpr;
    }

    using IndexNotationRewriter::visit;

    void visit(const AddNode* op) {
      // Sum every reduction variables over each term
      onlyOneTerm = false;

      IndexExpr a = addReductions(op->a);
      IndexExpr b = addReductions(op->b);
      if (a == op->a && b == op->b) {
        expr = op;
      }
      else {
        expr = new AddNode(a, b);
      }
    }

    void visit(const SubNode* op) {
      // Sum every reduction variables over each term
      onlyOneTerm = false;

      IndexExpr a = addReductions(op->a);
      IndexExpr b = addReductions(op->b);
      if (a == op->a && b == op->b) {
        expr = op;
      }
      else {
        expr = new SubNode(a, b);
      }
    }
  };
  return Assignment(assignment.getLhs(),
                    MakeReductionNotation(free, provGraph).einsum(expr),
                    assignment.getOperator());
}
2772+
2773+
/// Statement-level entry point: asserts that `stmt` is in einsum notation and
/// forwards it, as an Assignment, to the Assignment overload.
IndexStmt makeReductionNotationScheduled(IndexStmt stmt, ProvenanceGraph provGraph) {
  taco_iassert(isEinsumNotation(stmt));
  Assignment einsumAssignment = to<Assignment>(stmt);
  return makeReductionNotationScheduled(einsumAssignment, provGraph);
}
2777+
2778+
/// Converts a reduction-notation assignment to concrete notation, taking the
/// schedule described by `provGraph` into account.
///
/// The conversion proceeds in three steps:
///  1. reductions that cover the whole right-hand side become foralls around
///     a compound (Add) assignment;
///  2. a forall is emitted for every free variable. A fully derived free
///     variable is used directly; otherwise the forall is emitted over a
///     fully derived descendant of it that appears on the right-hand side;
///  3. each remaining reduction is replaced by a where statement whose
///     producer computes the reduction into a temporary.
///
/// \param stmt      an assignment in (scheduled) reduction notation; asserted
/// \param provGraph relations between index variables introduced by scheduling
/// \return the statement in concrete notation
IndexStmt makeConcreteNotationScheduled(IndexStmt stmt, ProvenanceGraph provGraph) {
  std::string reason;
  taco_iassert(isReductionNotationScheduled(stmt, provGraph, &reason))
      << "Not reduction notation: " << stmt << std::endl << reason;
  taco_iassert(isa<Assignment>(stmt));

  // Free variables and reductions covering the whole rhs become top level loops
  vector<IndexVar> freeVars = to<Assignment>(stmt).getFreeVars();

  struct RemoveTopLevelReductions : IndexNotationRewriter {
    using IndexNotationRewriter::visit;

    void visit(const AssignmentNode* node) {
      // Easiest to just walk down the reduction node until we find something
      // that's not a reduction
      vector<IndexVar> topLevelReductions;
      IndexExpr rhs = node->rhs;
      while (isa<Reduction>(rhs)) {
        Reduction reduction = to<Reduction>(rhs);
        topLevelReductions.push_back(reduction.getVar());
        rhs = reduction.getExpr();
      }

      if (rhs != node->rhs) {
        // The stripped reductions turn into loops around an accumulating
        // assignment; the outermost reduction becomes the outermost loop.
        stmt = Assignment(node->lhs, rhs, Add());
        for (auto& i : util::reverse(topLevelReductions)) {
          stmt = forall(i, stmt);
        }
      }
      else {
        stmt = node;
      }
    }
  };
  stmt = RemoveTopLevelReductions().rewrite(stmt);

  // This gets the list of indexVars on the rhs of an assignment
  // TODO: check to make sure that we want to get ALL rhs indexVars (not just the upper level)
  vector<IndexVar> rhsVars;
  match(stmt,
    function<void(const AccessNode*, Matcher*)>([&](const AccessNode* op, Matcher*) {
      rhsVars.insert(rhsVars.end(), op->indexVars.begin(), op->indexVars.end());
    }),
    function<void(const AssignmentNode*, Matcher*)>([&](const AssignmentNode* op, Matcher* ctx) {
      // Only descend into the rhs so that lhs accesses are not collected.
      ctx->match(op->rhs);
    })
  );

  // Emit the freeVars as foralls if the freeVars are fully derived
  // else emit the fully derived descendant of the freeVar found in rhsVars
  for (auto& i : util::reverse(freeVars)) {
    if (provGraph.isFullyDerived(i)) {
      stmt = forall(i, stmt);
      continue;
    }

    // Default loop variable when no descendant of `i` appears on the rhs: the
    // first rhs variable, as before. If the rhs has no index variables at all
    // (rhsVars is empty), fall back to `i` itself rather than dereferencing
    // the begin() iterator of an empty vector.
    IndexVar derivedI = i;
    if (!rhsVars.empty()) {
      derivedI = rhsVars.front();
    }

    auto derivedVars = provGraph.getFullyDerivedDescendants(i);
    for (auto& derivedVar : derivedVars) {
      // No early exit: if several descendants appear on the rhs, the last
      // one in `derivedVars` is used.
      if (std::find(rhsVars.begin(), rhsVars.end(), derivedVar) != rhsVars.end()) {
        derivedI = derivedVar;
      }
    }
    stmt = forall(derivedI, stmt);
  }

  // Replace other reductions with where and forall statements
  struct ReplaceReductionsWithWheres : IndexNotationRewriter {
    using IndexNotationRewriter::visit;

    // The reduction being extracted by the current visit, and the temporary
    // that replaces it in the consumer expression.
    Reduction reduction;
    TensorVar t;

    void visit(const AssignmentNode* node) {
      reduction = Reduction();
      t = TensorVar();

      IndexExpr rhs = rewrite(node->rhs);

      // nothing was rewritten
      if (rhs == node->rhs) {
        stmt = node;
        return;
      }

      taco_iassert(t.defined() && reduction.defined());
      IndexStmt consumer = Assignment(node->lhs, rhs, node->op);
      IndexStmt producer = forall(reduction.getVar(),
                                  Assignment(t, reduction.getExpr(),
                                             reduction.getOp()));
      // Both sides are rewritten again so that any further reductions they
      // contain are extracted as well.
      stmt = where(rewrite(consumer), rewrite(producer));
    }

    void visit(const ReductionNode* node) {
      // only rewrite one reduction at a time
      if (reduction.defined()) {
        expr = node;
        return;
      }

      reduction = node;
      t = TensorVar("t" + util::toString(node->var),
                    node->getDataType());
      expr = t;
    }
  };
  stmt = ReplaceReductionsWithWheres().rewrite(stmt);
  return stmt;
}
26082886

26092887
vector<TensorVar> getResults(IndexStmt stmt) {
26102888
vector<TensorVar> result;

0 commit comments

Comments
 (0)