Skip to content

Commit 6e41694

Browse files
nirvikbaruahweiya711
authored andcommitted
Intial commit for new fix
1 parent 88fcb6c commit 6e41694

File tree

6 files changed

+109
-5
lines changed

6 files changed

+109
-5
lines changed

.vscode/settings.json

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
{
2+
"files.associations": {
3+
"array": "cpp",
4+
"atomic": "cpp",
5+
"bit": "cpp",
6+
"*.tcc": "cpp",
7+
"cctype": "cpp",
8+
"chrono": "cpp",
9+
"clocale": "cpp",
10+
"cmath": "cpp",
11+
"compare": "cpp",
12+
"complex": "cpp",
13+
"concepts": "cpp",
14+
"condition_variable": "cpp",
15+
"cstdarg": "cpp",
16+
"cstddef": "cpp",
17+
"cstdint": "cpp",
18+
"cstdio": "cpp",
19+
"cstdlib": "cpp",
20+
"cstring": "cpp",
21+
"ctime": "cpp",
22+
"cwchar": "cpp",
23+
"cwctype": "cpp",
24+
"deque": "cpp",
25+
"forward_list": "cpp",
26+
"list": "cpp",
27+
"map": "cpp",
28+
"set": "cpp",
29+
"unordered_map": "cpp",
30+
"unordered_set": "cpp",
31+
"vector": "cpp",
32+
"exception": "cpp",
33+
"algorithm": "cpp",
34+
"functional": "cpp",
35+
"iterator": "cpp",
36+
"memory": "cpp",
37+
"memory_resource": "cpp",
38+
"numeric": "cpp",
39+
"optional": "cpp",
40+
"random": "cpp",
41+
"ratio": "cpp",
42+
"string": "cpp",
43+
"string_view": "cpp",
44+
"system_error": "cpp",
45+
"tuple": "cpp",
46+
"type_traits": "cpp",
47+
"utility": "cpp",
48+
"fstream": "cpp",
49+
"initializer_list": "cpp",
50+
"iomanip": "cpp",
51+
"iosfwd": "cpp",
52+
"iostream": "cpp",
53+
"istream": "cpp",
54+
"limits": "cpp",
55+
"mutex": "cpp",
56+
"new": "cpp",
57+
"ostream": "cpp",
58+
"ranges": "cpp",
59+
"sstream": "cpp",
60+
"stdexcept": "cpp",
61+
"stop_token": "cpp",
62+
"streambuf": "cpp",
63+
"thread": "cpp",
64+
"typeindex": "cpp",
65+
"typeinfo": "cpp",
66+
"valarray": "cpp",
67+
"variant": "cpp"
68+
}
69+
}

src/index_notation/index_notation.cpp

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1514,6 +1514,14 @@ IndexStmt IndexStmt::precompute(IndexExpr expr, IndexVar i, IndexVar iw, TensorV
15141514
IndexStmt transformed = *this;
15151515
string reason;
15161516

1517+
if (i != iw) {
1518+
IndexVarRel rel = IndexVarRel(new PrecomputeRelNode(i, iw));
1519+
transformed = Transformation(AddSuchThatPredicates({rel})).apply(transformed, &reason);
1520+
if (!transformed.defined()) {
1521+
taco_uerror << reason;
1522+
}
1523+
}
1524+
15171525
transformed = Transformation(Precompute(expr, i, iw, workspace)).apply(transformed, &reason);
15181526
if (!transformed.defined()) {
15191527
taco_uerror << reason;
@@ -2358,6 +2366,7 @@ bool isConcreteNotation(IndexStmt stmt, std::string* reason) {
23582366
for (auto& var : op->indexVars) {
23592367
// non underived variables may appear in temporaries, but we don't check these
23602368
if (!boundVars.contains(var) && provGraph.isUnderived(var) && (provGraph.isFullyDerived(var) || !provGraph.isRecoverable(var, definedVars))) {
2369+
// cout << "Variable: " << var << " Statement: " << stmt << endl;
23612370
*reason = "all variables in concrete notation must be bound by a "
23622371
"forall statement";
23632372
isConcrete = false;

src/index_notation/provenance_graph.cpp

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1120,9 +1120,12 @@ bool ProvenanceGraph::isAvailable(IndexVar indexVar, std::set<IndexVar> defined)
11201120

11211121
bool ProvenanceGraph::isRecoverable(taco::IndexVar indexVar, std::set<taco::IndexVar> defined) const {
11221122
// all children are either defined or recoverable from their children
1123-
for (const IndexVar& child : getChildren(indexVar)) {
1124-
if (!defined.count(child) && (isFullyDerived(child) || !isRecoverable(child, defined))) {
1125-
return false;
1123+
// precompute relations are always recoverable since their children never appear in the same loop
1124+
if (!(childRelMap.at(indexVar).getRelType() == IndexVarRelType::PRECOMPUTE)) {
1125+
for (const IndexVar& child : getChildren(indexVar)) {
1126+
if (!defined.count(child) && (isFullyDerived(child) || !isRecoverable(child, defined))) {
1127+
return false;
1128+
}
11261129
}
11271130
}
11281131
return true;

src/lower/iterator.cpp

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -521,16 +521,28 @@ Iterators::Iterators(IndexStmt stmt, const map<TensorVar, Expr>& tensorVars)
521521
{
522522
ProvenanceGraph provGraph = ProvenanceGraph(stmt);
523523
set<IndexVar> underivedAdded;
524+
524525
// Create dimension iterators
525526
match(stmt,
526527
function<void(const ForallNode*, Matcher*)>([&](auto n, auto m) {
527528
content->modeIterators.insert({n->indexVar, Iterator(n->indexVar, !provGraph.hasCoordBounds(n->indexVar) && provGraph.isCoordVariable(n->indexVar))});
529+
cout << "Adding following index var to iterators: " << n->indexVar << " for statement (" << n->stmt << ")" << endl;
528530
for (const IndexVar& underived : provGraph.getUnderivedAncestors(n->indexVar)) {
529531
if (!underivedAdded.count(underived)) {
532+
cout << "Adding following underived ancestor to iterators: " << underived << endl;
530533
content->modeIterators.insert({underived, underived});
531534
underivedAdded.insert(underived);
532535
}
533536
}
537+
538+
// Insert all children of current index variable into iterators as well
539+
for (const IndexVar& child : provGraph.getChildren(n->indexVar)) {
540+
if (!underivedAdded.count(child)) {
541+
content->modeIterators.insert({child, child});
542+
underivedAdded.insert(child);
543+
}
544+
}
545+
534546
m->match(n->stmt);
535547
})
536548
);
@@ -553,6 +565,8 @@ Iterators::Iterators(IndexStmt stmt, const map<TensorVar, Expr>& tensorVars)
553565
for (auto& iterator : content->levelIterators) {
554566
content->modeAccesses.insert({iterator.second, iterator.first});
555567
}
568+
569+
// cout << "FINISHED ITERATORS BUILDING" << endl;
556570
}
557571

558572

@@ -662,6 +676,8 @@ ModeAccess Iterators::modeAccess(Iterator iterator) const
662676
Iterator Iterators::modeIterator(IndexVar indexVar) const
663677
{
664678
taco_iassert(content != nullptr);
679+
cout << "Searching for " << indexVar << " in "
680+
<< util::join(content->modeIterators) << endl;
665681
taco_iassert(util::contains(content->modeIterators, indexVar));
666682
return content->modeIterators.at(indexVar);
667683
}

src/lower/merge_lattice.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -613,7 +613,9 @@ MergeLattice MergeLattice::make(Forall forall, Iterators iterators, ProvenanceGr
613613

614614
vector<IndexVar> underivedAncestors = provGraph.getUnderivedAncestors(indexVar);
615615
for (auto ancestor : underivedAncestors) {
616+
// cout << "Is recoverable from merge lattice: " << ancestor << endl;
616617
if(!provGraph.isRecoverable(ancestor, definedIndexVars)) {
618+
// cout << "Is not recoverable for ancestor " << ancestor << endl;
617619
return MergeLattice({MergePoint({iterators.modeIterator(indexVar)}, {}, {})});
618620
}
619621
}

test/tests-scheduling.cpp

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -276,7 +276,7 @@ TEST(scheduling, lowerSparseMulSparse) {
276276
// codegen->compile(compute, true);
277277
}
278278

279-
TEST(scheduling, precomputeIndependentIndexVars) {
279+
TEST(scheduling, NIRVIK_TEST) {
280280
Tensor<double> A("A", {16}, Format{Dense});
281281
Tensor<double> B("B", {16}, Format{Dense});
282282
Tensor<double> C("C", {16}, Format{Dense});
@@ -312,7 +312,7 @@ TEST(scheduling, precomputeIndependentIndexVars) {
312312
ASSERT_TENSOR_EQ(A, expected);
313313
}
314314

315-
TEST(scheduling, precomputeIndependentIndexVarsSplit) {
315+
TEST(scheduling, FAILED_TEST) {
316316
Tensor<double> A("A", {16}, Format{Dense});
317317
Tensor<double> B("B", {16}, Format{Dense});
318318
Tensor<double> C("C", {16}, Format{Dense});
@@ -347,6 +347,11 @@ TEST(scheduling, precomputeIndependentIndexVarsSplit) {
347347
expected.assemble();
348348
expected.compute();
349349

350+
ir::IRPrinter irp = ir::IRPrinter(cout);
351+
ir::Stmt compute = lower(stmt, "compute", false, true);
352+
cout << "Imperative IR" << endl;
353+
irp.print(compute);
354+
350355
ASSERT_TENSOR_EQ(A, expected);
351356
}
352357

0 commit comments

Comments
 (0)