Skip to content

Commit d9707c8

Browse files
committed
Code cleanup for precompute transformation algorithm
1 parent 3ef1af0 commit d9707c8

File tree

1 file changed

+53
-124
lines changed

1 file changed

+53
-124
lines changed

src/index_notation/transformations.cpp

Lines changed: 53 additions & 124 deletions
Original file line numberDiff line numberDiff line change
@@ -358,74 +358,6 @@ IndexStmt Precompute::apply(IndexStmt stmt, std::string* reason) const {
358358
return returnStmt;
359359
}
360360

361-
Forall buildProducer(int indexOriginal,
362-
std::map<IndexVar, IndexVar>& substitutions,
363-
TensorVar& ws,
364-
const std::vector<IndexVar>& i_vars,
365-
const std::vector<IndexVar>& iw_vars,
366-
const IndexExpr& e, vector<IndexVar> forallIndexVars) {
367-
368-
int index = i_vars.size() - indexOriginal - 1;
369-
substitutions[i_vars[index]] = iw_vars[index];
370-
371-
if (indexOriginal == 0) {
372-
auto assignment = ws(iw_vars) = replace(e, substitutions);
373-
374-
if (!assignment.getReductionVars().empty()) {
375-
assignment = Assignment(assignment.getLhs(), assignment.getRhs(), Add());
376-
} else {
377-
assignment = Assignment(assignment.getLhs(), assignment.getRhs());
378-
}
379-
380-
auto assignmentIndexVars = assignment.getIndexVars();
381-
//auto concreteStmt = stmt.concretizeScheduled(provGraph, forallIndexVarList);
382-
IndexStmt concreteStmt = assignment;
383-
for (auto &i : util::reverse(forallIndexVars)) {
384-
if (provGraph.isFullyDerived(i) &&
385-
std::find(assignmentIndexVars.begin(), assignmentIndexVars.end(), i) != assignmentIndexVars.end()) {
386-
concreteStmt = forall(i, concreteStmt);
387-
} else if (provGraph.isFullyDerived(i) && substitutions.find(i) != substitutions.end()
388-
&& std::find(assignmentIndexVars.begin(), assignmentIndexVars.end(), substitutions[i]) != assignmentIndexVars.end()) {
389-
concreteStmt = forall(substitutions[i], concreteStmt);
390-
} else {
391-
for (auto& underivedI : provGraph.getUnderivedAncestors(i)) {
392-
if (std::find(assignmentIndexVars.begin(), assignmentIndexVars.end(), underivedI) != assignmentIndexVars.end()) {
393-
concreteStmt = forall(i, concreteStmt);
394-
} else if (substitutions.find(underivedI) != substitutions.end()
395-
&& std::find(assignmentIndexVars.begin(), assignmentIndexVars.end(), substitutions[underivedI]) != assignmentIndexVars.end()) {
396-
concreteStmt = forall(substitutions[i], concreteStmt);
397-
}
398-
}
399-
}
400-
}
401-
402-
taco_iassert(isa<Forall>(concreteStmt)) << "Transformed producer statement does not begin with a Forall";
403-
return to<Forall>(concreteStmt);
404-
}
405-
return buildProducer(indexOriginal - 1, substitutions, ws, i_vars,
406-
iw_vars, e, forallIndexVars);
407-
}
408-
409-
Forall buildConsumer(Assignment assignment, vector<IndexVar> forallIndexVars) {
410-
auto assignmentIndexVars = assignment.getIndexVars();
411-
IndexStmt concreteStmt = assignment;
412-
for (auto &i : util::reverse(forallIndexVars)) {
413-
if (provGraph.isFullyDerived(i) && std::find(assignmentIndexVars.begin(), assignmentIndexVars.end(), i) != assignmentIndexVars.end()) {
414-
concreteStmt = forall(i, concreteStmt);
415-
} else {
416-
for (auto& underivedI : provGraph.getUnderivedAncestors(i)) {
417-
if (std::find(assignmentIndexVars.begin(), assignmentIndexVars.end(), underivedI) != assignmentIndexVars.end()) {
418-
concreteStmt = forall(i, concreteStmt);
419-
}
420-
}
421-
}
422-
}
423-
//auto concreteStmt = assignment.concretizeScheduled(provGraph, forallIndexVarList);
424-
//cout << "Build consumer stmt: " << concreteStmt << endl;
425-
taco_iassert(isa<Forall>(concreteStmt)) << "Transformed producer statement does not begin with a Forall";
426-
return to<Forall>(concreteStmt);
427-
}
428-
429361
bool containsIndexVarScheduled(vector<IndexVar> indexVars,
430362
IndexVar indexVar) {
431363
bool contains = false;
@@ -451,72 +383,69 @@ IndexStmt Precompute::apply(IndexStmt stmt, std::string* reason) const {
451383
Forall foralli(node);
452384
std::vector<IndexVar> i_vars = precompute.getIVars();
453385

454-
//if (foralli.getIndexVar() == i_vars[0]) {
455-
vector<IndexVar> forallIndexVars;
456-
match(foralli,
457-
function<void(const ForallNode*)>([&](const ForallNode* op) {
458-
forallIndexVars.push_back(op->indexVar);
459-
})
460-
);
386+
vector<IndexVar> forallIndexVars;
387+
match(foralli,
388+
function<void(const ForallNode*)>([&](const ForallNode* op) {
389+
forallIndexVars.push_back(op->indexVar);
390+
})
391+
);
461392

462-
IndexStmt s = foralli.getStmt();
463-
TensorVar ws = precompute.getWorkspace();
464-
IndexExpr e = precompute.getExpr();
465-
std::vector<IndexVar> iw_vars = precompute.getIWVars();
393+
IndexStmt s = foralli.getStmt();
394+
TensorVar ws = precompute.getWorkspace();
395+
IndexExpr e = precompute.getExpr();
396+
std::vector<IndexVar> iw_vars = precompute.getIWVars();
466397

467-
map<IndexVar, IndexVar> substitutions;
468-
taco_iassert(i_vars.size() == iw_vars.size()) << "i_vars and iw_vars lists must be the same size";
398+
map<IndexVar, IndexVar> substitutions;
399+
taco_iassert(i_vars.size() == iw_vars.size()) << "i_vars and iw_vars lists must be the same size";
469400

470-
for (int index = 0; index < (int)i_vars.size(); index++) {
471-
substitutions[i_vars[index]] = iw_vars[index];
472-
}
401+
for (int index = 0; index < (int)i_vars.size(); index++) {
402+
substitutions[i_vars[index]] = iw_vars[index];
403+
}
404+
405+
// Build consumer by replacing with temporary (in replacedStmt)
406+
IndexStmt replacedStmt = replace(s, {{e, ws(i_vars) }});
407+
if (replacedStmt != s) {
408+
// Then modify the replacedStmt to have the correct foralls
409+
// by concretizing the consumer assignment
410+
411+
auto consumerAssignment = getConsumerAssignment(replacedStmt, ws);
412+
auto consumerIndexVars = consumerAssignment.getIndexVars();
413+
414+
auto producerAssignment = getProducerAssignment(ws, i_vars, iw_vars, e, substitutions);
415+
auto producerIndexVars = producerAssignment.getIndexVars();
416+
417+
vector<IndexVar> producerForallIndexVars;
418+
vector<IndexVar> consumerForallIndexVars;
419+
vector<IndexVar> outerForallIndexVars;
473420

474-
// Build consumer by replacing with temporary (in replacedStmt)
475-
IndexStmt replacedStmt = replace(s, {{e, ws(i_vars) }});
476-
if (replacedStmt != s) {
477-
// Then modify the replacedStmt to have the correct foralls
478-
// by concretizing the consumer assignment
479-
480-
auto consumerAssignment = getConsumerAssignment(replacedStmt, ws);
481-
auto consumerIndexVars = consumerAssignment.getIndexVars();
482-
483-
auto producerAssignment = getProducerAssignment(ws, i_vars, iw_vars, e, substitutions);
484-
auto producerIndexVars = producerAssignment.getIndexVars();
485-
486-
// IndexStmt consumer = buildConsumer(consumerAssignment, forallIndexVars);
487-
//
488-
// // Buld producer by concretizing the producer assignment
489-
// std::map<IndexVar, IndexVar> substitutions;
490-
// IndexStmt producer = buildProducer(i_vars.size() - 1, substitutions, ws, i_vars, iw_vars, e, forallIndexVars);
491-
vector<IndexVar> producerForallIndexVars;
492-
vector<IndexVar> consumerForallIndexVars;
493-
vector<IndexVar> outerForallIndexVars;
494-
for (auto &i : util::reverse(forallIndexVars)) {
495-
if (containsIndexVarScheduled(i_vars, i)) {
496-
producerForallIndexVars.push_back(substitutions[i]);
421+
bool stopForallDistribution = false;
422+
for (auto &i : util::reverse(forallIndexVars)) {
423+
if (!stopForallDistribution && containsIndexVarScheduled(i_vars, i)) {
424+
producerForallIndexVars.push_back(substitutions[i]);
425+
consumerForallIndexVars.push_back(i);
426+
} else {
427+
auto consumerContains = containsIndexVarScheduled(consumerIndexVars, i);
428+
auto producerContains = containsIndexVarScheduled(producerIndexVars, i);
429+
if (stopForallDistribution || (producerContains && consumerContains)) {
430+
outerForallIndexVars.push_back(i);
431+
stopForallDistribution = true;
432+
} else if (!stopForallDistribution && consumerContains) {
497433
consumerForallIndexVars.push_back(i);
498-
} else {
499-
auto consumerContains = containsIndexVarScheduled(consumerIndexVars, i);
500-
auto producerContains = containsIndexVarScheduled(producerIndexVars, i);
501-
if (producerContains && consumerContains) {
502-
outerForallIndexVars.push_back(i);
503-
} else if (consumerContains) {
504-
consumerForallIndexVars.push_back(i);
505-
} else if (producerContains) {
506-
producerForallIndexVars.push_back(i);
507-
}
434+
} else if (!stopForallDistribution && producerContains) {
435+
producerForallIndexVars.push_back(i);
508436
}
509437
}
438+
}
510439

511-
IndexStmt consumer = generateForalls(consumerAssignment, consumerForallIndexVars);
440+
IndexStmt consumer = generateForalls(consumerAssignment, consumerForallIndexVars);
512441

513-
IndexStmt producer = generateForalls(producerAssignment, producerForallIndexVars);
514-
Where where(consumer, producer);
442+
IndexStmt producer = generateForalls(producerAssignment, producerForallIndexVars);
443+
Where where(consumer, producer);
515444

516-
stmt = generateForalls(where, outerForallIndexVars);
517-
return;
518-
}
519-
IndexNotationRewriter::visit(node);
445+
stmt = generateForalls(where, outerForallIndexVars);
446+
return;
447+
}
448+
IndexNotationRewriter::visit(node);
520449
}
521450
};
522451

0 commit comments

Comments
 (0)