Skip to content

Commit 23448cd

Browse files
committed
Add in fixes for isRecoverable() based on if there is a where stmt
1 parent 412483c commit 23448cd

File tree

5 files changed

+113
-7
lines changed

5 files changed

+113
-7
lines changed

include/taco/index_notation/provenance_graph.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -353,6 +353,9 @@ class ProvenanceGraph {
353353

354354
/// Node is recoverable if children appear in defined
355355
bool isRecoverable(IndexVar indexVar, std::set<IndexVar> defined) const;
356+
bool isRecoverableFull(IndexVar indexVar, std::set<IndexVar> defined) const;
357+
bool isRecoverableProducer(IndexVar indexVar, std::set<IndexVar> defined) const;
358+
bool isRecoverableConsumer(IndexVar indexVar, std::set<IndexVar> defined) const;
356359

357360
/// Node is recoverable if at most 1 unknown variable in relationship (parents + siblings)
358361
bool isChildRecoverable(taco::IndexVar indexVar, std::set<taco::IndexVar> defined) const;

src/index_notation/index_notation.cpp

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2366,8 +2366,12 @@ bool isConcreteNotation(IndexStmt stmt, std::string* reason) {
23662366
for (auto& var : op->indexVars) {
23672367
// non underived variables may appear in temporaries, but we don't check these
23682368
if (!boundVars.contains(var) && provGraph.isUnderived(var) && (provGraph.isFullyDerived(var) || !provGraph.isRecoverable(var, definedVars))) {
2369+
string string2 = "definedvars: ";
2370+
for (auto& d : definedVars)
2371+
string2.append(d.getName() + ", ");
2372+
23692373
*reason = "all variables in concrete notation must be bound by a "
2370-
"forall statement";
2374+
"forall statement" + var.getName() + string2;
23712375
isConcrete = false;
23722376
}
23732377
}

src/index_notation/provenance_graph.cpp

Lines changed: 62 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1120,13 +1120,66 @@ bool ProvenanceGraph::isAvailable(IndexVar indexVar, std::set<IndexVar> defined)
11201120

11211121
bool ProvenanceGraph::isRecoverable(taco::IndexVar indexVar, std::set<taco::IndexVar> defined) const {
11221122
// all children are either defined or recoverable from their children
1123-
// precompute relations are always recoverable since their children never appear in the same loop
1124-
if (!(childRelMap.count(indexVar) && childRelMap.at(indexVar).getRelType() == IndexVarRelType::PRECOMPUTE)) {
1123+
// This checks the definedVars list to determine where in the statement the variables are trying to be
1124+
// recovered from ( either on the producer or consumer side of a where stmt or not in a where stmt)
1125+
bool producer = false;
1126+
bool consumer = false;
1127+
for (auto& def : defined) {
1128+
if (childRelMap.count(def) && childRelMap.at(def).getRelType() == IndexVarRelType::PRECOMPUTE) {
1129+
consumer = true;
1130+
}
1131+
if (parentRelMap.count(def) && parentRelMap.at(def).getRelType() == IndexVarRelType::PRECOMPUTE) {
1132+
producer = true;
1133+
}
1134+
}
1135+
if (producer) {
1136+
return isRecoverableProducer(indexVar, defined);
1137+
}
1138+
else if (consumer) {
1139+
return isRecoverableConsumer(indexVar, defined);
1140+
} else {
1141+
return isRecoverableFull(indexVar, defined);
1142+
}
1143+
}
1144+
1145+
bool ProvenanceGraph::isRecoverableFull(taco::IndexVar indexVar, std::set<taco::IndexVar> defined) const {
1146+
// all children are either defined or recoverable from their children
1147+
// precompute relations are treated as normal relations
11251148
for (const IndexVar& child : getChildren(indexVar)) {
1126-
if (!defined.count(child) && (isFullyDerived(child) || !isRecoverable(child, defined))) {
1149+
if (!defined.count(child) && (isFullyDerived(child) || !isRecoverableFull(child, defined))) {
1150+
return false;
1151+
}
1152+
}
1153+
return true;
1154+
}
1155+
1156+
bool ProvenanceGraph::isRecoverableConsumer(taco::IndexVar indexVar, std::set<taco::IndexVar> defined) const {
1157+
vector<IndexVar> childPrecompute;
1158+
if (childRelMap.count(indexVar) && childRelMap.at(indexVar).getRelType() == IndexVarRelType::PRECOMPUTE) {
1159+
childPrecompute = childrenMap.at(indexVar);
1160+
}
1161+
for (const IndexVar& child : getChildren(indexVar)) {
1162+
if (!childPrecompute.empty() && child == childPrecompute[0]) continue;
1163+
if (!defined.count(child) && (isFullyDerived(child) ||
1164+
(childRelMap.count(child) && childRelMap.at(child).getRelType() == IndexVarRelType::PRECOMPUTE) || !isRecoverableConsumer(child, defined))) {
1165+
return false;
1166+
}
1167+
}
1168+
return true;
1169+
}
1170+
1171+
bool ProvenanceGraph::isRecoverableProducer(taco::IndexVar indexVar, std::set<taco::IndexVar> defined) const {
1172+
vector<IndexVar> childPrecompute;
1173+
for (const IndexVar& child : getChildren(indexVar)) {
1174+
if (childRelMap.count(child) && childRelMap.at(child).getRelType() == IndexVarRelType::PRECOMPUTE) {
1175+
auto precomputeChild = childrenMap.at(child)[0];
1176+
if (!defined.count(precomputeChild) && (isFullyDerived(precomputeChild) || !isRecoverableProducer(precomputeChild, defined))) {
11271177
return false;
11281178
}
11291179
}
1180+
else if (!defined.count(child) && (isFullyDerived(child) || !isRecoverableProducer(child, defined))) {
1181+
return false;
1182+
}
11301183
}
11311184
return true;
11321185
}
@@ -1289,6 +1342,12 @@ bool ProvenanceGraph::hasExactBound(IndexVar indexVar) const {
12891342
{
12901343
return rel.getNode<BoundRelNode>()->getBoundType() == BoundType::MaxExact;
12911344
}
1345+
// else if (rel.getRelType() == SPLIT)
1346+
// {
1347+
// return rel.getNode<SplitRelNode>()->getInnerVar() == indexVar;
1348+
// } else if (rel.getRelType() == PRECOMPUTE) {
1349+
// return hasExactBound(rel.getNode<PrecomputeRelNode>()->getParentVar());
1350+
// }
12921351
// TODO: include non-irregular variables
12931352
return false;
12941353
}

src/lower/lowerer_impl_imperative.cpp

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -840,7 +840,6 @@ Stmt LowererImplImperative::lowerForall(Forall forall)
840840

841841
Stmt LowererImplImperative::lowerForallCloned(Forall forall) {
842842
// want to emit guards outside of loop to prevent unstructured loop exits
843-
844843
// construct guard
845844
// underived or pos variables that have a descendant that has not been defined yet
846845
vector<IndexVar> varsWithGuard;
@@ -858,7 +857,6 @@ Stmt LowererImplImperative::lowerForallCloned(Forall forall) {
858857
}
859858
}
860859
}
861-
862860
// determine min and max values for vars given already defined variables.
863861
// we do a recovery where we fill in undefined variables with either 0's or the max of their iteration
864862
std::map<IndexVar, Expr> minVarValues;
@@ -903,7 +901,6 @@ Stmt LowererImplImperative::lowerForallCloned(Forall forall) {
903901
minVarValues[var] = provGraph.recoverVariable(var, currentDefinedVarOrder, underivedBounds, minChildValues, iterators);
904902
maxVarValues[var] = provGraph.recoverVariable(var, currentDefinedVarOrder, underivedBounds, maxChildValues, iterators);
905903
}
906-
907904
// Build guards
908905
Expr guardCondition;
909906
for (auto var : varsWithGuard) {

test/tests-workspaces.cpp

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -187,3 +187,46 @@ TEST(workspaces, tile_denseMatMul) {
187187

188188
}
189189

190+
TEST(DISABLED_workspaces, multiplePrecomputeIndependentIndexVarsSplit) {
191+
192+
Tensor<double> A("A", {16}, Format{Dense});
193+
Tensor<double> B("B", {16}, Format{Dense});
194+
Tensor<double> C("C", {16}, Format{Dense});
195+
Tensor<double> D("D", {16}, Format{Dense});
196+
197+
for (int i = 0; i < 16; i++) {
198+
B.insert({i}, (double) i);
199+
C.insert({i}, (double) i);
200+
D.insert({i}, (double) i);
201+
}
202+
203+
IndexVar i("i");
204+
IndexVar iw1("iw1");
205+
IndexVar iw2("iw2");
206+
IndexVar iw2_outter("iw2_outer");
207+
IndexVar iw2_inner("iw2_inner");
208+
A(i) = B(i) + C(i) + D(i);
209+
210+
// Precompute then split iw tensor
211+
IndexStmt stmt = A.getAssignment().concretize();
212+
TensorVar precomputed1("precomputed1", Type(Float64, {16}), taco::dense);
213+
TensorVar precomputed2("precomputed2", Type(Float64, {16}), taco::dense);
214+
stmt = stmt.precompute(A.getAssignment().getRhs(), i, iw1, precomputed1);
215+
cout << stmt.concretize() << endl;
216+
stmt = stmt.precompute(B(iw1)+C(iw1), iw1, iw2, precomputed2);
217+
//.split(iw2,iw2_outter, iw2_inner, 8);
218+
219+
cout << stmt.concretize() << endl;
220+
A.compile(stmt.concretize());
221+
A.assemble();
222+
A.compute();
223+
224+
Tensor<double> expected("expected", {16}, Format{Dense});
225+
expected(i) = B(i) + C(i);
226+
expected.compile();
227+
expected.assemble();
228+
expected.compute();
229+
230+
ASSERT_TENSOR_EQ(A, expected);
231+
}
232+

0 commit comments

Comments
 (0)