Skip to content

Commit 3ef1af0

Browse files
committed
Change precompute(...) algorithm to use workspaces paper algorithm
1 parent aa57ebc commit 3ef1af0

File tree

5 files changed

+264
-77
lines changed

5 files changed

+264
-77
lines changed

include/taco/index_notation/index_notation.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -551,7 +551,7 @@ class IndexStmt : public util::IntrusivePtr<const IndexStmtNode> {
551551

552552
/// Takes any index notation and concretizes unknowns to make it concrete notation
553553
/// given a Provenance Graph of indexVars
554-
IndexStmt concretizeScheduled(ProvenanceGraph provGraph) const;
554+
IndexStmt concretizeScheduled(ProvenanceGraph provGraph, std::vector<IndexVar> forallIndexVarList) const;
555555

556556
/// The \code{split} transformation splits (strip-mines) an index
557557
/// variable into two nested index variables, where the size of the
@@ -1162,7 +1162,7 @@ IndexStmt makeReductionNotationScheduled(IndexStmt, ProvenanceGraph);
11621162
/// Convert reduction notation to concrete notation, by inserting forall nodes,
11631163
/// replacing reduction nodes by compound assignments, and inserting temporaries
11641164
/// as needed while taking into account a schedule given by the Provenance Graph.
1165-
IndexStmt makeConcreteNotationScheduled(IndexStmt, ProvenanceGraph);
1165+
IndexStmt makeConcreteNotationScheduled(IndexStmt, ProvenanceGraph, std::vector<IndexVar> forallIndexVars);
11661166

11671167
/// Returns the results of the index statement, in the order they appear.
11681168
std::vector<TensorVar> getResults(IndexStmt stmt);

src/error/error_checks.cpp

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,9 @@ std::pair<bool, string> dimensionsTypecheck(const std::vector<IndexVar>& resultV
4141
for (size_t mode = 0; mode < resultVars.size(); mode++) {
4242
IndexVar var = resultVars[mode];
4343
auto dimension = shape.getDimension(mode);
44-
if (util::contains(indexVarDims,var) && indexVarDims.at(var) != dimension) {
44+
if (util::contains(indexVarDims,var) && indexVarDims.at(var) != dimension &&
45+
!(indexVarDims.at(var).isIndexVarSized() && indexVarDims.at(var).getIndexVarSize() == var) &&
46+
!(dimension.isIndexVarSized() && dimension.getIndexVarSize() == var)) {
4547
errors.push_back(addDimensionError(var, indexVarDims.at(var), dimension));
4648
} else {
4749
indexVarDims.insert({var, dimension});
@@ -63,7 +65,9 @@ std::pair<bool, string> dimensionsTypecheck(const std::vector<IndexVar>& resultV
6365
dimension = Dimension(a.getIndexSet(mode).size());
6466
}
6567

66-
if (util::contains(indexVarDims,var) && indexVarDims.at(var) != dimension) {
68+
if (util::contains(indexVarDims,var) && indexVarDims.at(var) != dimension &&
69+
!(indexVarDims.at(var).isIndexVarSized() && indexVarDims.at(var).getIndexVarSize() == var) &&
70+
!(dimension.isIndexVarSized() && dimension.getIndexVarSize() == var)) {
6771
errors.push_back(addDimensionError(var, indexVarDims.at(var), dimension));
6872
} else {
6973
indexVarDims.insert({var, dimension});

src/index_notation/index_notation.cpp

Lines changed: 71 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1463,14 +1463,18 @@ map<IndexVar,Dimension> IndexStmt::getIndexVarDomains() const {
14631463

14641464

14651465

1466-
IndexStmt IndexStmt::concretizeScheduled(ProvenanceGraph provGraph) const {
1466+
IndexStmt IndexStmt::concretizeScheduled(ProvenanceGraph provGraph, vector<IndexVar> forallIndexVarList) const {
14671467
IndexStmt stmt = *this;
14681468
string r;
1469-
if (isEinsumNotation(stmt)) {
1469+
cout << "Pre concretized stmt: " << stmt << endl;
1470+
if (isEinsumNotation(stmt, &r)) {
14701471
stmt = makeReductionNotationScheduled(stmt, provGraph);
1472+
cout << "Post Reduction Stmt: " << stmt << endl;
14711473
}
1474+
cout << r << endl;
14721475
if (isReductionNotationScheduled(stmt, provGraph, &r)) {
1473-
stmt = makeConcreteNotationScheduled(stmt, provGraph);
1476+
stmt = makeConcreteNotationScheduled(stmt, provGraph, forallIndexVarList);
1477+
cout << "Post Concretize Stmt: " << stmt << endl;
14741478
}
14751479
return stmt;
14761480
}
@@ -2775,17 +2779,22 @@ IndexStmt makeReductionNotationScheduled(IndexStmt stmt, ProvenanceGraph provGra
27752779
return makeReductionNotationScheduled(to<Assignment>(stmt), provGraph);
27762780
}
27772781

2778-
IndexStmt makeConcreteNotationScheduled(IndexStmt stmt, ProvenanceGraph provGraph) {
2782+
IndexStmt makeConcreteNotationScheduled(IndexStmt stmt, ProvenanceGraph provGraph, vector<IndexVar> forallIndexVars) {
27792783
std::string reason;
27802784
taco_iassert(isReductionNotationScheduled(stmt, provGraph, &reason))
27812785
<< "Not reduction notation: " << stmt << std::endl << reason;
27822786
taco_iassert(isa<Assignment>(stmt));
27832787

27842788
// Free variables and reductions covering the whole rhs become top level loops
27852789
vector<IndexVar> freeVars = to<Assignment>(stmt).getFreeVars();
2790+
vector<IndexVar> reductionAndFreeVars;
27862791

27872792
struct RemoveTopLevelReductions : IndexNotationRewriter {
27882793
using IndexNotationRewriter::visit;
2794+
vector<IndexVar> forallIndexVars;
2795+
vector<IndexVar> reductionAndFreeVars;
2796+
2797+
RemoveTopLevelReductions(vector<IndexVar> forallIndexVars) : forallIndexVars(forallIndexVars) {}
27892798

27902799
void visit(const AssignmentNode* node) {
27912800
// Easiest to just walk down the reduction node until we find something
@@ -2800,43 +2809,86 @@ IndexStmt makeConcreteNotationScheduled(IndexStmt stmt, ProvenanceGraph provGrap
28002809

28012810
if (rhs != node->rhs) {
28022811
stmt = Assignment(node->lhs, rhs, Add());
2803-
for (auto& i : util::reverse(topLevelReductions)) {
2804-
stmt = forall(i, stmt);
2812+
if (forallIndexVars.empty()) {
2813+
for (auto &i : util::reverse(topLevelReductions)) {
2814+
stmt = forall(i, stmt);
2815+
}
2816+
} else {
2817+
reductionAndFreeVars.insert(reductionAndFreeVars.end(), topLevelReductions.begin(), topLevelReductions.end());
28052818
}
28062819
}
28072820
else {
28082821
stmt = node;
28092822
}
28102823
}
28112824
};
2812-
stmt = RemoveTopLevelReductions().rewrite(stmt);
2813-
2825+
auto rewriter = RemoveTopLevelReductions(forallIndexVars);
2826+
stmt = rewriter.rewrite(stmt);
2827+
reductionAndFreeVars = rewriter.reductionAndFreeVars;
28142828
// This gets the list of indexVars on the rhs of an assignment
28152829
// TODO: check to make sure that we want to get ALL rhs indexVars (not just the upper level)
28162830
vector<IndexVar> rhsVars;
28172831
match(stmt,
28182832
function<void(const AccessNode*, Matcher*)>([&](const AccessNode* op, Matcher* ctx) {
2819-
rhsVars.insert(rhsVars.end(), op->indexVars.begin(), op->indexVars.end());
2833+
for (auto &i : op->indexVars) {
2834+
if (std::find(rhsVars.begin(), rhsVars.end(), i) == rhsVars.end()) {
2835+
rhsVars.push_back(i);
2836+
}
2837+
}
28202838
}),
28212839
function<void(const AssignmentNode*, Matcher*)>([&](const AssignmentNode* op, Matcher* ctx) {
28222840
ctx->match(op->rhs);
28232841
})
28242842
);
28252843

2844+
cout << "freeVars: ";
2845+
for (auto &i : freeVars) {
2846+
cout << i << ", ";
2847+
}
2848+
cout << endl;
2849+
2850+
cout << "rhsVars: ";
2851+
for (auto &i : rhsVars) {
2852+
cout << i << ", ";
2853+
}
2854+
cout << endl;
2855+
2856+
cout << "forallIndexVars: ";
2857+
for (auto &i : forallIndexVars) {
2858+
cout << i << ", ";
2859+
}
2860+
cout << endl;
28262861
// Emit the freeVars as foralls if the freeVars are fully derived
28272862
// else emit the fully derived descendant of the freeVar found in rhsVars
2828-
for (auto& i : util::reverse(freeVars)) {
2829-
if (provGraph.isFullyDerived(i))
2830-
stmt = forall(i, stmt);
2831-
else {
2832-
auto derivedVars = provGraph.getFullyDerivedDescendants(i);
2833-
IndexVar derivedI = *rhsVars.begin();
2834-
for (auto& derivedVar : derivedVars) {
2835-
if (std::find(rhsVars.begin(), rhsVars.end(), derivedVar) != rhsVars.end()) {
2836-
derivedI = derivedVar;
2863+
if (forallIndexVars.empty()) {
2864+
for (auto &i : util::reverse(freeVars)) {
2865+
if (provGraph.isFullyDerived(i))
2866+
stmt = forall(i, stmt);
2867+
else {
2868+
auto derivedVars = provGraph.getFullyDerivedDescendants(i);
2869+
IndexVar derivedI = *rhsVars.begin();
2870+
for (auto &derivedVar : derivedVars) {
2871+
if (std::find(rhsVars.begin(), rhsVars.end(), derivedVar) != rhsVars.end()) {
2872+
derivedI = derivedVar;
2873+
}
2874+
}
2875+
stmt = forall(derivedI, stmt);
2876+
}
2877+
}
2878+
} else {
2879+
reductionAndFreeVars.insert(reductionAndFreeVars.end(), freeVars.begin(), freeVars.end());
2880+
for (auto &i : util::reverse(forallIndexVars)) {
2881+
if (std::find(reductionAndFreeVars.begin(), reductionAndFreeVars.end(), i) != reductionAndFreeVars.end())
2882+
stmt = forall(i, stmt);
2883+
else {
2884+
auto ancestorVars = provGraph.getUnderivedAncestors(i);
2885+
IndexVar ancestorI = *reductionAndFreeVars.begin();
2886+
for (auto &ancestorVar : ancestorVars) {
2887+
if (std::find(reductionAndFreeVars.begin(), reductionAndFreeVars.end(), ancestorVar) != reductionAndFreeVars.end()) {
2888+
stmt = forall(i, stmt);
2889+
}
28372890
}
28382891
}
2839-
stmt = forall(derivedI, stmt);
28402892
}
28412893
}
28422894

0 commit comments

Comments
 (0)