@@ -1122,62 +1122,35 @@ bool ProvenanceGraph::isRecoverable(taco::IndexVar indexVar, std::set<taco::Inde
11221122 // all children are either defined or recoverable from their children
11231123 // This checks the definedVars list to determine where in the statement the variables are trying to be
11241124 // recovered from ( either on the producer or consumer side of a where stmt or not in a where stmt)
1125- bool producer = false ;
1126- bool consumer = false ;
1125+ vector<IndexVar> producers ;
1126+ vector<IndexVar> consumers ;
11271127 for (auto & def : defined ) {
11281128 if (childRelMap.count (def) && childRelMap.at (def).getRelType () == IndexVarRelType::PRECOMPUTE) {
1129- consumer = true ;
1129+ consumers. push_back (def) ;
11301130 }
11311131 if (parentRelMap.count (def) && parentRelMap.at (def).getRelType () == IndexVarRelType::PRECOMPUTE) {
1132- producer = true ;
1132+ producers. push_back (def) ;
11331133 }
11341134 }
1135- if (producer) {
1136- return isRecoverableProducer (indexVar, defined );
1137- }
1138- else if (consumer) {
1139- return isRecoverableConsumer (indexVar, defined );
1140- } else {
1141- return isRecoverableFull (indexVar, defined );
1142- }
1143- }
11441135
1145- bool ProvenanceGraph::isRecoverableFull (taco::IndexVar indexVar, std::set<taco::IndexVar> defined ) const {
1146- // all children are either defined or recoverable from their children
1147- // precompute relations are treated as normal relations
1148- for (const IndexVar& child : getChildren (indexVar)) {
1149- if (!defined .count (child) && (isFullyDerived (child) || !isRecoverableFull (child, defined ))) {
1150- return false ;
1151- }
1152- }
1153- return true ;
1136+ return isRecoverablePrecompute (indexVar, defined , producers, consumers);
11541137}
11551138
1156- bool ProvenanceGraph::isRecoverableConsumer (taco::IndexVar indexVar, std::set<taco::IndexVar> defined ) const {
1139+ bool ProvenanceGraph::isRecoverablePrecompute (taco::IndexVar indexVar, std::set<taco::IndexVar> defined , vector<IndexVar> producers, vector<IndexVar> consumers ) const {
11571140 vector<IndexVar> childPrecompute;
1158- if (childRelMap. count (indexVar) && childRelMap. at ( indexVar). getRelType () == IndexVarRelType::PRECOMPUTE ) {
1159- childPrecompute = childrenMap. at (indexVar) ;
1141+ if (std::find (consumers. begin (), consumers. end (), indexVar) != consumers. end () ) {
1142+ return true ;
11601143 }
1161- for (const IndexVar& child : getChildren (indexVar)) {
1162- if (!childPrecompute.empty () && child == childPrecompute[0 ]) continue ;
1163- if (!defined .count (child) && (isFullyDerived (child) ||
1164- (childRelMap.count (child) && childRelMap.at (child).getRelType () == IndexVarRelType::PRECOMPUTE) || !isRecoverableConsumer (child, defined ))) {
1165- return false ;
1144+ if (!producers.empty () && (childRelMap.count (indexVar) && childRelMap.at (indexVar).getRelType () == IndexVarRelType::PRECOMPUTE)) {
1145+ auto precomputeChild = getChildren (indexVar)[0 ];
1146+ if (std::find (producers.begin (), producers.end (), precomputeChild) != producers.end ()) {
1147+ return true ;
11661148 }
1149+ return isRecoverablePrecompute (precomputeChild, defined , producers, consumers);
11671150 }
1168- return true ;
1169- }
1170-
1171- bool ProvenanceGraph::isRecoverableProducer (taco::IndexVar indexVar, std::set<taco::IndexVar> defined ) const {
1172- vector<IndexVar> childPrecompute;
11731151 for (const IndexVar& child : getChildren (indexVar)) {
1174- if (childRelMap.count (child) && childRelMap.at (child).getRelType () == IndexVarRelType::PRECOMPUTE) {
1175- auto precomputeChild = childrenMap.at (child)[0 ];
1176- if (!defined .count (precomputeChild) && (isFullyDerived (precomputeChild) || !isRecoverableProducer (precomputeChild, defined ))) {
1177- return false ;
1178- }
1179- }
1180- else if (!defined .count (child) && (isFullyDerived (child) || !isRecoverableProducer (child, defined ))) {
1152+ if (!defined .count (child) && (isFullyDerived (child)
1153+ || !isRecoverablePrecompute (child, defined , producers, consumers))) {
11811154 return false ;
11821155 }
11831156 }
0 commit comments