Skip to content

Commit 392fc66

Browse files
committed
Merge remote-tracking branch 'origin' into multidim-workspace
2 parents d9707c8 + d464c7c commit 392fc66

File tree

9 files changed

+4286
-4066
lines changed

9 files changed

+4286
-4066
lines changed

include/taco/index_notation/provenance_graph.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -354,6 +354,9 @@ class ProvenanceGraph {
354354
/// Node is recoverable if children appear in defined
355355
bool isRecoverable(IndexVar indexVar, std::set<IndexVar> defined) const;
356356

357+
/// isRecoverable helper method to handle precompute relations and where statements in the provenance graph
358+
bool isRecoverablePrecompute(IndexVar indexVar, std::set<IndexVar> defined, std::vector<IndexVar> producers, std::vector<IndexVar> consumers) const;
359+
357360
/// Node is recoverable if at most 1 unknown variable in relationship (parents + siblings)
358361
bool isChildRecoverable(taco::IndexVar indexVar, std::set<taco::IndexVar> defined) const;
359362

include/taco/lower/lowerer_impl.h

Lines changed: 22 additions & 426 deletions
Large diffs are not rendered by default.

include/taco/lower/lowerer_impl_imperative.h

Lines changed: 533 additions & 0 deletions
Large diffs are not rendered by default.

src/index_notation/index_notation.cpp

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1534,15 +1534,18 @@ IndexStmt IndexStmt::precompute(IndexExpr expr, std::vector<IndexVar> i_vars,
15341534
// TODO: need to assert they are same length
15351535
IndexStmt transformed = *this;
15361536
string reason;
1537-
// if (i != iw) {
1538-
// IndexVarRel rel = IndexVarRel(new PrecomputeRelNode(i, iw));
1539-
// transformed = Transformation(AddSuchThatPredicates({rel})).apply(transformed, &reason);
1540-
// if (!transformed.defined()) {
1541-
// taco_uerror << reason;
1542-
// }
1543-
// }
1537+
1538+
// FIXME: need to re-enable this later
1539+
// if (i != iw) {
1540+
// IndexVarRel rel = IndexVarRel(new PrecomputeRelNode(i, iw));
1541+
// transformed = Transformation(AddSuchThatPredicates({rel})).apply(transformed, &reason);
1542+
// if (!transformed.defined()) {
1543+
// taco_uerror << reason;
1544+
// }
1545+
// }
15441546

15451547
transformed = Transformation(Precompute(expr, i_vars, iw_vars, workspace)).apply(transformed, &reason);
1548+
15461549
if (!transformed.defined()) {
15471550
taco_uerror << reason;
15481551
}

src/index_notation/provenance_graph.cpp

Lines changed: 32 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1120,8 +1120,39 @@ bool ProvenanceGraph::isAvailable(IndexVar indexVar, std::set<IndexVar> defined)
11201120

11211121
bool ProvenanceGraph::isRecoverable(taco::IndexVar indexVar, std::set<taco::IndexVar> defined) const {
11221122
// all children are either defined or recoverable from their children
1123+
// This checks the definedVars list to determine where in the statement the variables are trying to be
1124+
// recovered from ( either on the producer or consumer side of a where stmt or not in a where stmt)
1125+
vector<IndexVar> producers;
1126+
vector<IndexVar> consumers;
1127+
for (auto& def : defined) {
1128+
if (childRelMap.count(def) && childRelMap.at(def).getRelType() == IndexVarRelType::PRECOMPUTE) {
1129+
consumers.push_back(def);
1130+
}
1131+
if (parentRelMap.count(def) && parentRelMap.at(def).getRelType() == IndexVarRelType::PRECOMPUTE) {
1132+
producers.push_back(def);
1133+
}
1134+
}
1135+
1136+
return isRecoverablePrecompute(indexVar, defined, producers, consumers);
1137+
}
1138+
1139+
bool ProvenanceGraph::isRecoverablePrecompute(taco::IndexVar indexVar, std::set<taco::IndexVar> defined,
1140+
vector<IndexVar> producers, vector<IndexVar> consumers) const {
1141+
vector<IndexVar> childPrecompute;
1142+
if (std::find(consumers.begin(), consumers.end(), indexVar) != consumers.end()) {
1143+
return true;
1144+
}
1145+
if (!producers.empty() && (childRelMap.count(indexVar) &&
1146+
childRelMap.at(indexVar).getRelType() == IndexVarRelType::PRECOMPUTE)) {
1147+
auto precomputeChild = getChildren(indexVar)[0];
1148+
if (std::find(producers.begin(), producers.end(), precomputeChild) != producers.end()) {
1149+
return true;
1150+
}
1151+
return isRecoverablePrecompute(precomputeChild, defined, producers, consumers);
1152+
}
11231153
for (const IndexVar& child : getChildren(indexVar)) {
1124-
if (!defined.count(child) && (isFullyDerived(child) || !isRecoverable(child, defined))) {
1154+
if (!defined.count(child) && (isFullyDerived(child) ||
1155+
!isRecoverablePrecompute(child, defined, producers, consumers))) {
11251156
return false;
11261157
}
11271158
}

src/lower/iterator.cpp

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -531,6 +531,15 @@ Iterators::Iterators(IndexStmt stmt, const map<TensorVar, Expr>& tensorVars)
531531
underivedAdded.insert(underived);
532532
}
533533
}
534+
535+
// Insert all children of current index variable into iterators as well
536+
for (const IndexVar& child : provGraph.getChildren(n->indexVar)) {
537+
if (!underivedAdded.count(child)) {
538+
content->modeIterators.insert({child, child});
539+
underivedAdded.insert(child);
540+
}
541+
}
542+
534543
m->match(n->stmt);
535544
})
536545
);

src/lower/lower.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
#include "taco/ir/ir_printer.h"
1616

1717
#include "taco/lower/lowerer_impl.h"
18+
#include "taco/lower/lowerer_impl_imperative.h"
1819
#include "taco/lower/iterator.h"
1920
#include "mode_access.h"
2021

@@ -33,7 +34,7 @@ namespace taco {
3334

3435

3536
// class Lowerer
36-
Lowerer::Lowerer() : impl(new LowererImpl()) {
37+
Lowerer::Lowerer() : impl(new LowererImplImperative()) {
3738
}
3839

3940
Lowerer::Lowerer(LowererImpl* impl) : impl(impl) {

0 commit comments

Comments
 (0)