@@ -1120,8 +1120,39 @@ bool ProvenanceGraph::isAvailable(IndexVar indexVar, std::set<IndexVar> defined)
11201120
11211121bool ProvenanceGraph::isRecoverable (taco::IndexVar indexVar, std::set<taco::IndexVar> defined ) const {
11221122 // all children are either defined or recoverable from their children
1123+ // This checks the definedVars list to determine where in the statement the variables are trying to be
1124+ // recovered from ( either on the producer or consumer side of a where stmt or not in a where stmt)
1125+ vector<IndexVar> producers;
1126+ vector<IndexVar> consumers;
1127+ for (auto & def : defined ) {
1128+ if (childRelMap.count (def) && childRelMap.at (def).getRelType () == IndexVarRelType::PRECOMPUTE) {
1129+ consumers.push_back (def);
1130+ }
1131+ if (parentRelMap.count (def) && parentRelMap.at (def).getRelType () == IndexVarRelType::PRECOMPUTE) {
1132+ producers.push_back (def);
1133+ }
1134+ }
1135+
1136+ return isRecoverablePrecompute (indexVar, defined , producers, consumers);
1137+ }
1138+
1139+ bool ProvenanceGraph::isRecoverablePrecompute (taco::IndexVar indexVar, std::set<taco::IndexVar> defined ,
1140+ vector<IndexVar> producers, vector<IndexVar> consumers) const {
1141+ vector<IndexVar> childPrecompute;
1142+ if (std::find (consumers.begin (), consumers.end (), indexVar) != consumers.end ()) {
1143+ return true ;
1144+ }
1145+ if (!producers.empty () && (childRelMap.count (indexVar) &&
1146+ childRelMap.at (indexVar).getRelType () == IndexVarRelType::PRECOMPUTE)) {
1147+ auto precomputeChild = getChildren (indexVar)[0 ];
1148+ if (std::find (producers.begin (), producers.end (), precomputeChild) != producers.end ()) {
1149+ return true ;
1150+ }
1151+ return isRecoverablePrecompute (precomputeChild, defined , producers, consumers);
1152+ }
11231153 for (const IndexVar& child : getChildren (indexVar)) {
1124- if (!defined .count (child) && (isFullyDerived (child) || !isRecoverable (child, defined ))) {
1154+ if (!defined .count (child) && (isFullyDerived (child) ||
1155+ !isRecoverablePrecompute (child, defined , producers, consumers))) {
11251156 return false ;
11261157 }
11271158 }
0 commit comments