Skip to content

Commit 6bf8f0f

Browse files
committed
Fix isRecoverable helper methods to handle precompute relation
1 parent 23448cd commit 6bf8f0f

File tree

3 files changed

+18
-88
lines changed

3 files changed

+18
-88
lines changed

include/taco/index_notation/provenance_graph.h

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -353,9 +353,9 @@ class ProvenanceGraph {
353353

354354
/// Node is recoverable if children appear in defined
355355
bool isRecoverable(IndexVar indexVar, std::set<IndexVar> defined) const;
356-
bool isRecoverableFull(IndexVar indexVar, std::set<IndexVar> defined) const;
357-
bool isRecoverableProducer(IndexVar indexVar, std::set<IndexVar> defined) const;
358-
bool isRecoverableConsumer(IndexVar indexVar, std::set<IndexVar> defined) const;
356+
357+
/// isRecoverable helper method to handle precompute relations and where statements in the provenance graph
358+
bool isRecoverablePrecompute(IndexVar indexVar, std::set<IndexVar> defined, std::vector<IndexVar> producers, std::vector<IndexVar> consumers) const;
359359

360360
/// Node is recoverable if at most 1 unknown variable in relationship (parents + siblings)
361361
bool isChildRecoverable(taco::IndexVar indexVar, std::set<taco::IndexVar> defined) const;

src/index_notation/provenance_graph.cpp

Lines changed: 15 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -1122,62 +1122,35 @@ bool ProvenanceGraph::isRecoverable(taco::IndexVar indexVar, std::set<taco::Inde
11221122
// all children are either defined or recoverable from their children
11231123
// This checks the definedVars list to determine where in the statement the variables are trying to be
11241124
// recovered from ( either on the producer or consumer side of a where stmt or not in a where stmt)
1125-
bool producer = false;
1126-
bool consumer = false;
1125+
vector<IndexVar> producers;
1126+
vector<IndexVar> consumers;
11271127
for (auto& def : defined) {
11281128
if (childRelMap.count(def) && childRelMap.at(def).getRelType() == IndexVarRelType::PRECOMPUTE) {
1129-
consumer = true;
1129+
consumers.push_back(def);
11301130
}
11311131
if (parentRelMap.count(def) && parentRelMap.at(def).getRelType() == IndexVarRelType::PRECOMPUTE) {
1132-
producer = true;
1132+
producers.push_back(def);
11331133
}
11341134
}
1135-
if (producer) {
1136-
return isRecoverableProducer(indexVar, defined);
1137-
}
1138-
else if (consumer) {
1139-
return isRecoverableConsumer(indexVar, defined);
1140-
} else {
1141-
return isRecoverableFull(indexVar, defined);
1142-
}
1143-
}
11441135

1145-
bool ProvenanceGraph::isRecoverableFull(taco::IndexVar indexVar, std::set<taco::IndexVar> defined) const {
1146-
// all children are either defined or recoverable from their children
1147-
// precompute relations are treated as normal relations
1148-
for (const IndexVar& child : getChildren(indexVar)) {
1149-
if (!defined.count(child) && (isFullyDerived(child) || !isRecoverableFull(child, defined))) {
1150-
return false;
1151-
}
1152-
}
1153-
return true;
1136+
return isRecoverablePrecompute(indexVar, defined, producers, consumers);
11541137
}
11551138

1156-
bool ProvenanceGraph::isRecoverableConsumer(taco::IndexVar indexVar, std::set<taco::IndexVar> defined) const {
1139+
bool ProvenanceGraph::isRecoverablePrecompute(taco::IndexVar indexVar, std::set<taco::IndexVar> defined, vector<IndexVar> producers, vector<IndexVar> consumers) const {
11571140
vector<IndexVar> childPrecompute;
1158-
if (childRelMap.count(indexVar) && childRelMap.at(indexVar).getRelType() == IndexVarRelType::PRECOMPUTE) {
1159-
childPrecompute = childrenMap.at(indexVar);
1141+
if (std::find(consumers.begin(), consumers.end(), indexVar) != consumers.end()) {
1142+
return true;
11601143
}
1161-
for (const IndexVar& child : getChildren(indexVar)) {
1162-
if (!childPrecompute.empty() && child == childPrecompute[0]) continue;
1163-
if (!defined.count(child) && (isFullyDerived(child) ||
1164-
(childRelMap.count(child) && childRelMap.at(child).getRelType() == IndexVarRelType::PRECOMPUTE) || !isRecoverableConsumer(child, defined))) {
1165-
return false;
1144+
if (!producers.empty() && (childRelMap.count(indexVar) && childRelMap.at(indexVar).getRelType() == IndexVarRelType::PRECOMPUTE)) {
1145+
auto precomputeChild = getChildren(indexVar)[0];
1146+
if (std::find(producers.begin(), producers.end(), precomputeChild) != producers.end()) {
1147+
return true;
11661148
}
1149+
return isRecoverablePrecompute(precomputeChild, defined, producers, consumers);
11671150
}
1168-
return true;
1169-
}
1170-
1171-
bool ProvenanceGraph::isRecoverableProducer(taco::IndexVar indexVar, std::set<taco::IndexVar> defined) const {
1172-
vector<IndexVar> childPrecompute;
11731151
for (const IndexVar& child : getChildren(indexVar)) {
1174-
if (childRelMap.count(child) && childRelMap.at(child).getRelType() == IndexVarRelType::PRECOMPUTE) {
1175-
auto precomputeChild = childrenMap.at(child)[0];
1176-
if (!defined.count(precomputeChild) && (isFullyDerived(precomputeChild) || !isRecoverableProducer(precomputeChild, defined))) {
1177-
return false;
1178-
}
1179-
}
1180-
else if (!defined.count(child) && (isFullyDerived(child) || !isRecoverableProducer(child, defined))) {
1152+
if (!defined.count(child) && (isFullyDerived(child)
1153+
|| !isRecoverablePrecompute(child, defined, producers, consumers))) {
11811154
return false;
11821155
}
11831156
}

test/tests-workspaces.cpp

Lines changed: 0 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -187,46 +187,3 @@ TEST(workspaces, tile_denseMatMul) {
187187

188188
}
189189

190-
TEST(DISABLED_workspaces, multiplePrecomputeIndependentIndexVarsSplit) {
191-
192-
Tensor<double> A("A", {16}, Format{Dense});
193-
Tensor<double> B("B", {16}, Format{Dense});
194-
Tensor<double> C("C", {16}, Format{Dense});
195-
Tensor<double> D("D", {16}, Format{Dense});
196-
197-
for (int i = 0; i < 16; i++) {
198-
B.insert({i}, (double) i);
199-
C.insert({i}, (double) i);
200-
D.insert({i}, (double) i);
201-
}
202-
203-
IndexVar i("i");
204-
IndexVar iw1("iw1");
205-
IndexVar iw2("iw2");
206-
IndexVar iw2_outter("iw2_outer");
207-
IndexVar iw2_inner("iw2_inner");
208-
A(i) = B(i) + C(i) + D(i);
209-
210-
// Precompute then split iw tensor
211-
IndexStmt stmt = A.getAssignment().concretize();
212-
TensorVar precomputed1("precomputed1", Type(Float64, {16}), taco::dense);
213-
TensorVar precomputed2("precomputed2", Type(Float64, {16}), taco::dense);
214-
stmt = stmt.precompute(A.getAssignment().getRhs(), i, iw1, precomputed1);
215-
cout << stmt.concretize() << endl;
216-
stmt = stmt.precompute(B(iw1)+C(iw1), iw1, iw2, precomputed2);
217-
//.split(iw2,iw2_outter, iw2_inner, 8);
218-
219-
cout << stmt.concretize() << endl;
220-
A.compile(stmt.concretize());
221-
A.assemble();
222-
A.compute();
223-
224-
Tensor<double> expected("expected", {16}, Format{Dense});
225-
expected(i) = B(i) + C(i);
226-
expected.compile();
227-
expected.assemble();
228-
expected.compute();
229-
230-
ASSERT_TENSOR_EQ(A, expected);
231-
}
232-

0 commit comments

Comments
 (0)