@@ -358,74 +358,6 @@ IndexStmt Precompute::apply(IndexStmt stmt, std::string* reason) const {
358358 return returnStmt;
359359 }
360360
361- Forall buildProducer (int indexOriginal,
362- std::map<IndexVar, IndexVar>& substitutions,
363- TensorVar& ws,
364- const std::vector<IndexVar>& i_vars,
365- const std::vector<IndexVar>& iw_vars,
366- const IndexExpr& e, vector<IndexVar> forallIndexVars) {
367-
368- int index = i_vars.size () - indexOriginal - 1 ;
369- substitutions[i_vars[index]] = iw_vars[index];
370-
371- if (indexOriginal == 0 ) {
372- auto assignment = ws (iw_vars) = replace (e, substitutions);
373-
374- if (!assignment.getReductionVars ().empty ()) {
375- assignment = Assignment (assignment.getLhs (), assignment.getRhs (), Add ());
376- } else {
377- assignment = Assignment (assignment.getLhs (), assignment.getRhs ());
378- }
379-
380- auto assignmentIndexVars = assignment.getIndexVars ();
381- // auto concreteStmt = stmt.concretizeScheduled(provGraph, forallIndexVarList);
382- IndexStmt concreteStmt = assignment;
383- for (auto &i : util::reverse (forallIndexVars)) {
384- if (provGraph.isFullyDerived (i) &&
385- std::find (assignmentIndexVars.begin (), assignmentIndexVars.end (), i) != assignmentIndexVars.end ()) {
386- concreteStmt = forall (i, concreteStmt);
387- } else if (provGraph.isFullyDerived (i) && substitutions.find (i) != substitutions.end ()
388- && std::find (assignmentIndexVars.begin (), assignmentIndexVars.end (), substitutions[i]) != assignmentIndexVars.end ()) {
389- concreteStmt = forall (substitutions[i], concreteStmt);
390- } else {
391- for (auto & underivedI : provGraph.getUnderivedAncestors (i)) {
392- if (std::find (assignmentIndexVars.begin (), assignmentIndexVars.end (), underivedI) != assignmentIndexVars.end ()) {
393- concreteStmt = forall (i, concreteStmt);
394- } else if (substitutions.find (underivedI) != substitutions.end ()
395- && std::find (assignmentIndexVars.begin (), assignmentIndexVars.end (), substitutions[underivedI]) != assignmentIndexVars.end ()) {
396- concreteStmt = forall (substitutions[i], concreteStmt);
397- }
398- }
399- }
400- }
401-
402- taco_iassert (isa<Forall>(concreteStmt)) << " Transformed producer statement does not begin with a Forall" ;
403- return to<Forall>(concreteStmt);
404- }
405- return buildProducer (indexOriginal - 1 , substitutions, ws, i_vars,
406- iw_vars, e, forallIndexVars);
407- }
408-
409- Forall buildConsumer (Assignment assignment, vector<IndexVar> forallIndexVars) {
410- auto assignmentIndexVars = assignment.getIndexVars ();
411- IndexStmt concreteStmt = assignment;
412- for (auto &i : util::reverse (forallIndexVars)) {
413- if (provGraph.isFullyDerived (i) && std::find (assignmentIndexVars.begin (), assignmentIndexVars.end (), i) != assignmentIndexVars.end ()) {
414- concreteStmt = forall (i, concreteStmt);
415- } else {
416- for (auto & underivedI : provGraph.getUnderivedAncestors (i)) {
417- if (std::find (assignmentIndexVars.begin (), assignmentIndexVars.end (), underivedI) != assignmentIndexVars.end ()) {
418- concreteStmt = forall (i, concreteStmt);
419- }
420- }
421- }
422- }
423- // auto concreteStmt = assignment.concretizeScheduled(provGraph, forallIndexVarList);
424- // cout << "Build consumer stmt: " << concreteStmt << endl;
425- taco_iassert (isa<Forall>(concreteStmt)) << " Transformed producer statement does not begin with a Forall" ;
426- return to<Forall>(concreteStmt);
427- }
428-
429361 bool containsIndexVarScheduled (vector<IndexVar> indexVars,
430362 IndexVar indexVar) {
431363 bool contains = false ;
@@ -451,72 +383,69 @@ IndexStmt Precompute::apply(IndexStmt stmt, std::string* reason) const {
451383 Forall foralli (node);
452384 std::vector<IndexVar> i_vars = precompute.getIVars ();
453385
454- // if (foralli.getIndexVar() == i_vars[0]) {
455- vector<IndexVar> forallIndexVars;
456- match (foralli,
457- function<void (const ForallNode*)>([&](const ForallNode* op) {
458- forallIndexVars.push_back (op->indexVar );
459- })
460- );
386+ vector<IndexVar> forallIndexVars;
387+ match (foralli,
388+ function<void (const ForallNode*)>([&](const ForallNode* op) {
389+ forallIndexVars.push_back (op->indexVar );
390+ })
391+ );
461392
462- IndexStmt s = foralli.getStmt ();
463- TensorVar ws = precompute.getWorkspace ();
464- IndexExpr e = precompute.getExpr ();
465- std::vector<IndexVar> iw_vars = precompute.getIWVars ();
393+ IndexStmt s = foralli.getStmt ();
394+ TensorVar ws = precompute.getWorkspace ();
395+ IndexExpr e = precompute.getExpr ();
396+ std::vector<IndexVar> iw_vars = precompute.getIWVars ();
466397
467- map<IndexVar, IndexVar> substitutions;
468- taco_iassert (i_vars.size () == iw_vars.size ()) << " i_vars and iw_vars lists must be the same size" ;
398+ map<IndexVar, IndexVar> substitutions;
399+ taco_iassert (i_vars.size () == iw_vars.size ()) << " i_vars and iw_vars lists must be the same size" ;
469400
470- for (int index = 0 ; index < (int )i_vars.size (); index++) {
471- substitutions[i_vars[index]] = iw_vars[index];
472- }
401+ for (int index = 0 ; index < (int )i_vars.size (); index++) {
402+ substitutions[i_vars[index]] = iw_vars[index];
403+ }
404+
405+ // Build consumer by replacing with temporary (in replacedStmt)
406+ IndexStmt replacedStmt = replace (s, {{e, ws (i_vars) }});
407+ if (replacedStmt != s) {
408+ // Then modify the replacedStmt to have the correct foralls
409+ // by concretizing the consumer assignment
410+
411+ auto consumerAssignment = getConsumerAssignment (replacedStmt, ws);
412+ auto consumerIndexVars = consumerAssignment.getIndexVars ();
413+
414+ auto producerAssignment = getProducerAssignment (ws, i_vars, iw_vars, e, substitutions);
415+ auto producerIndexVars = producerAssignment.getIndexVars ();
416+
417+ vector<IndexVar> producerForallIndexVars;
418+ vector<IndexVar> consumerForallIndexVars;
419+ vector<IndexVar> outerForallIndexVars;
473420
474- // Build consumer by replacing with temporary (in replacedStmt)
475- IndexStmt replacedStmt = replace (s, {{e, ws (i_vars) }});
476- if (replacedStmt != s) {
477- // Then modify the replacedStmt to have the correct foralls
478- // by concretizing the consumer assignment
479-
480- auto consumerAssignment = getConsumerAssignment (replacedStmt, ws);
481- auto consumerIndexVars = consumerAssignment.getIndexVars ();
482-
483- auto producerAssignment = getProducerAssignment (ws, i_vars, iw_vars, e, substitutions);
484- auto producerIndexVars = producerAssignment.getIndexVars ();
485-
486- // IndexStmt consumer = buildConsumer(consumerAssignment, forallIndexVars);
487- //
488- // // Buld producer by concretizing the producer assignment
489- // std::map<IndexVar, IndexVar> substitutions;
490- // IndexStmt producer = buildProducer(i_vars.size() - 1, substitutions, ws, i_vars, iw_vars, e, forallIndexVars);
491- vector<IndexVar> producerForallIndexVars;
492- vector<IndexVar> consumerForallIndexVars;
493- vector<IndexVar> outerForallIndexVars;
494- for (auto &i : util::reverse (forallIndexVars)) {
495- if (containsIndexVarScheduled (i_vars, i)) {
496- producerForallIndexVars.push_back (substitutions[i]);
421+ bool stopForallDistribution = false ;
422+ for (auto &i : util::reverse (forallIndexVars)) {
423+ if (!stopForallDistribution && containsIndexVarScheduled (i_vars, i)) {
424+ producerForallIndexVars.push_back (substitutions[i]);
425+ consumerForallIndexVars.push_back (i);
426+ } else {
427+ auto consumerContains = containsIndexVarScheduled (consumerIndexVars, i);
428+ auto producerContains = containsIndexVarScheduled (producerIndexVars, i);
429+ if (stopForallDistribution || (producerContains && consumerContains)) {
430+ outerForallIndexVars.push_back (i);
431+ stopForallDistribution = true ;
432+ } else if (!stopForallDistribution && consumerContains) {
497433 consumerForallIndexVars.push_back (i);
498- } else {
499- auto consumerContains = containsIndexVarScheduled (consumerIndexVars, i);
500- auto producerContains = containsIndexVarScheduled (producerIndexVars, i);
501- if (producerContains && consumerContains) {
502- outerForallIndexVars.push_back (i);
503- } else if (consumerContains) {
504- consumerForallIndexVars.push_back (i);
505- } else if (producerContains) {
506- producerForallIndexVars.push_back (i);
507- }
434+ } else if (!stopForallDistribution && producerContains) {
435+ producerForallIndexVars.push_back (i);
508436 }
509437 }
438+ }
510439
511- IndexStmt consumer = generateForalls (consumerAssignment, consumerForallIndexVars);
440+ IndexStmt consumer = generateForalls (consumerAssignment, consumerForallIndexVars);
512441
513- IndexStmt producer = generateForalls (producerAssignment, producerForallIndexVars);
514- Where where (consumer, producer);
442+ IndexStmt producer = generateForalls (producerAssignment, producerForallIndexVars);
443+ Where where (consumer, producer);
515444
516- stmt = generateForalls (where, outerForallIndexVars);
517- return ;
518- }
519- IndexNotationRewriter::visit (node);
445+ stmt = generateForalls (where, outerForallIndexVars);
446+ return ;
447+ }
448+ IndexNotationRewriter::visit (node);
520449 }
521450 };
522451
0 commit comments