@@ -1120,13 +1120,66 @@ bool ProvenanceGraph::isAvailable(IndexVar indexVar, std::set<IndexVar> defined)
11201120
11211121bool ProvenanceGraph::isRecoverable (taco::IndexVar indexVar, std::set<taco::IndexVar> defined ) const {
11221122 // all children are either defined or recoverable from their children
1123- // precompute relations are always recoverable since their children never appear in the same loop
1124- if (!(childRelMap.count (indexVar) && childRelMap.at (indexVar).getRelType () == IndexVarRelType::PRECOMPUTE)) {
1123+ // This checks the definedVars list to determine where in the statement the variables are trying to be
1124+ // recovered from ( either on the producer or consumer side of a where stmt or not in a where stmt)
1125+ bool producer = false ;
1126+ bool consumer = false ;
1127+ for (auto & def : defined ) {
1128+ if (childRelMap.count (def) && childRelMap.at (def).getRelType () == IndexVarRelType::PRECOMPUTE) {
1129+ consumer = true ;
1130+ }
1131+ if (parentRelMap.count (def) && parentRelMap.at (def).getRelType () == IndexVarRelType::PRECOMPUTE) {
1132+ producer = true ;
1133+ }
1134+ }
1135+ if (producer) {
1136+ return isRecoverableProducer (indexVar, defined );
1137+ }
1138+ else if (consumer) {
1139+ return isRecoverableConsumer (indexVar, defined );
1140+ } else {
1141+ return isRecoverableFull (indexVar, defined );
1142+ }
1143+ }
1144+
1145+ bool ProvenanceGraph::isRecoverableFull (taco::IndexVar indexVar, std::set<taco::IndexVar> defined ) const {
1146+ // all children are either defined or recoverable from their children
1147+ // precompute relations are treated as normal relations
11251148 for (const IndexVar& child : getChildren (indexVar)) {
1126- if (!defined .count (child) && (isFullyDerived (child) || !isRecoverable (child, defined ))) {
1149+ if (!defined .count (child) && (isFullyDerived (child) || !isRecoverableFull (child, defined ))) {
1150+ return false ;
1151+ }
1152+ }
1153+ return true ;
1154+ }
1155+
1156+ bool ProvenanceGraph::isRecoverableConsumer (taco::IndexVar indexVar, std::set<taco::IndexVar> defined ) const {
1157+ vector<IndexVar> childPrecompute;
1158+ if (childRelMap.count (indexVar) && childRelMap.at (indexVar).getRelType () == IndexVarRelType::PRECOMPUTE) {
1159+ childPrecompute = childrenMap.at (indexVar);
1160+ }
1161+ for (const IndexVar& child : getChildren (indexVar)) {
1162+ if (!childPrecompute.empty () && child == childPrecompute[0 ]) continue ;
1163+ if (!defined .count (child) && (isFullyDerived (child) ||
1164+ (childRelMap.count (child) && childRelMap.at (child).getRelType () == IndexVarRelType::PRECOMPUTE) || !isRecoverableConsumer (child, defined ))) {
1165+ return false ;
1166+ }
1167+ }
1168+ return true ;
1169+ }
1170+
1171+ bool ProvenanceGraph::isRecoverableProducer (taco::IndexVar indexVar, std::set<taco::IndexVar> defined ) const {
1172+ vector<IndexVar> childPrecompute;
1173+ for (const IndexVar& child : getChildren (indexVar)) {
1174+ if (childRelMap.count (child) && childRelMap.at (child).getRelType () == IndexVarRelType::PRECOMPUTE) {
1175+ auto precomputeChild = childrenMap.at (child)[0 ];
1176+ if (!defined .count (precomputeChild) && (isFullyDerived (precomputeChild) || !isRecoverableProducer (precomputeChild, defined ))) {
11271177 return false ;
11281178 }
11291179 }
1180+ else if (!defined .count (child) && (isFullyDerived (child) || !isRecoverableProducer (child, defined ))) {
1181+ return false ;
1182+ }
11301183 }
11311184 return true ;
11321185}
@@ -1289,6 +1342,12 @@ bool ProvenanceGraph::hasExactBound(IndexVar indexVar) const {
12891342 {
12901343 return rel.getNode <BoundRelNode>()->getBoundType () == BoundType::MaxExact;
12911344 }
1345+ // else if (rel.getRelType() == SPLIT)
1346+ // {
1347+ // return rel.getNode<SplitRelNode>()->getInnerVar() == indexVar;
1348+ // } else if (rel.getRelType() == PRECOMPUTE) {
1349+ // return hasExactBound(rel.getNode<PrecomputeRelNode>()->getParentVar());
1350+ // }
12921351 // TODO: include non-irregular variables
12931352 return false ;
12941353}
0 commit comments