Skip to content

Commit e72fd96

Browse files
committed
Fix some precompute transformation algorithm bugs that arose
1 parent 97edc84 commit e72fd96

File tree

2 files changed

+180
-54
lines changed

2 files changed

+180
-54
lines changed

src/index_notation/transformations.cpp

Lines changed: 53 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -383,67 +383,76 @@ IndexStmt Precompute::apply(IndexStmt stmt, std::string* reason) const {
383383
Forall foralli(node);
384384
std::vector<IndexVar> i_vars = precompute.getIVars();
385385

386-
vector<IndexVar> forallIndexVars;
386+
bool containsWhere = false;
387387
match(foralli,
388-
function<void(const ForallNode*)>([&](const ForallNode* op) {
389-
forallIndexVars.push_back(op->indexVar);
388+
function<void(const WhereNode*)>([&](const WhereNode* op) {
389+
containsWhere = true;
390390
})
391391
);
392392

393-
IndexStmt s = foralli.getStmt();
394-
TensorVar ws = precompute.getWorkspace();
395-
IndexExpr e = precompute.getExpr();
396-
std::vector<IndexVar> iw_vars = precompute.getIWVars();
393+
if (!containsWhere) {
394+
vector<IndexVar> forallIndexVars;
395+
match(foralli,
396+
function<void(const ForallNode*)>([&](const ForallNode* op) {
397+
forallIndexVars.push_back(op->indexVar);
398+
})
399+
);
397400

398-
map<IndexVar, IndexVar> substitutions;
399-
taco_iassert(i_vars.size() == iw_vars.size()) << "i_vars and iw_vars lists must be the same size";
401+
IndexStmt s = foralli.getStmt();
402+
TensorVar ws = precompute.getWorkspace();
403+
IndexExpr e = precompute.getExpr();
404+
std::vector<IndexVar> iw_vars = precompute.getIWVars();
400405

401-
for (int index = 0; index < (int)i_vars.size(); index++) {
402-
substitutions[i_vars[index]] = iw_vars[index];
403-
}
406+
map<IndexVar, IndexVar> substitutions;
407+
taco_iassert(i_vars.size() == iw_vars.size()) << "i_vars and iw_vars lists must be the same size";
404408

405-
// Build consumer by replacing with temporary (in replacedStmt)
406-
IndexStmt replacedStmt = replace(s, {{e, ws(i_vars) }});
407-
if (replacedStmt != s) {
408-
// Then modify the replacedStmt to have the correct foralls
409-
// by concretizing the consumer assignment
409+
for (int index = 0; index < (int)i_vars.size(); index++) {
410+
substitutions[i_vars[index]] = iw_vars[index];
411+
}
410412

411-
auto consumerAssignment = getConsumerAssignment(replacedStmt, ws);
412-
auto consumerIndexVars = consumerAssignment.getIndexVars();
413+
// Build consumer by replacing with temporary (in replacedStmt)
414+
IndexStmt replacedStmt = replace(s, {{e, ws(i_vars) }});
415+
if (replacedStmt != s) {
416+
// Then modify the replacedStmt to have the correct foralls
417+
// by concretizing the consumer assignment
413418

414-
auto producerAssignment = getProducerAssignment(ws, i_vars, iw_vars, e, substitutions);
415-
auto producerIndexVars = producerAssignment.getIndexVars();
419+
auto consumerAssignment = getConsumerAssignment(replacedStmt, ws);
420+
auto consumerIndexVars = consumerAssignment.getIndexVars();
416421

417-
vector<IndexVar> producerForallIndexVars;
418-
vector<IndexVar> consumerForallIndexVars;
419-
vector<IndexVar> outerForallIndexVars;
422+
auto producerAssignment = getProducerAssignment(ws, i_vars, iw_vars, e, substitutions);
423+
auto producerIndexVars = producerAssignment.getIndexVars();
420424

421-
bool stopForallDistribution = false;
422-
for (auto &i : util::reverse(forallIndexVars)) {
423-
if (!stopForallDistribution && containsIndexVarScheduled(i_vars, i)) {
424-
producerForallIndexVars.push_back(substitutions[i]);
425-
consumerForallIndexVars.push_back(i);
426-
} else {
427-
auto consumerContains = containsIndexVarScheduled(consumerIndexVars, i);
428-
auto producerContains = containsIndexVarScheduled(producerIndexVars, i);
429-
if (stopForallDistribution || (producerContains && consumerContains)) {
430-
outerForallIndexVars.push_back(i);
431-
stopForallDistribution = true;
432-
} else if (!stopForallDistribution && consumerContains) {
425+
vector<IndexVar> producerForallIndexVars;
426+
vector<IndexVar> consumerForallIndexVars;
427+
vector<IndexVar> outerForallIndexVars;
428+
429+
bool stopForallDistribution = false;
430+
for (auto &i : util::reverse(forallIndexVars)) {
431+
if (!stopForallDistribution && containsIndexVarScheduled(i_vars, i)) {
432+
producerForallIndexVars.push_back(substitutions[i]);
433433
consumerForallIndexVars.push_back(i);
434-
} else if (!stopForallDistribution && producerContains) {
435-
producerForallIndexVars.push_back(i);
434+
} else {
435+
auto consumerContains = containsIndexVarScheduled(consumerIndexVars, i);
436+
auto producerContains = containsIndexVarScheduled(producerIndexVars, i);
437+
if (stopForallDistribution || (producerContains && consumerContains)) {
438+
outerForallIndexVars.push_back(i);
439+
stopForallDistribution = true;
440+
} else if (!stopForallDistribution && consumerContains) {
441+
consumerForallIndexVars.push_back(i);
442+
} else if (!stopForallDistribution && producerContains) {
443+
producerForallIndexVars.push_back(i);
444+
}
436445
}
437446
}
438-
}
439447

440-
IndexStmt consumer = generateForalls(consumerAssignment, consumerForallIndexVars);
448+
IndexStmt consumer = generateForalls(consumerAssignment, consumerForallIndexVars);
441449

442-
IndexStmt producer = generateForalls(producerAssignment, producerForallIndexVars);
443-
Where where(consumer, producer);
450+
IndexStmt producer = generateForalls(producerAssignment, producerForallIndexVars);
451+
Where where(consumer, producer);
444452

445-
stmt = generateForalls(where, outerForallIndexVars);
446-
return;
453+
stmt = generateForalls(where, outerForallIndexVars);
454+
return;
455+
}
447456
}
448457
IndexNotationRewriter::visit(node);
449458
}

test/tests-workspaces.cpp

Lines changed: 127 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ TEST(workspaces, tile_vecElemMul_NoTail) {
4545
expected.compile();
4646
expected.assemble();
4747
expected.compute();
48-
ASSERT_TENSOR_EQ(A, expected);
48+
ASSERT_TENSOR_EQ(expected, A);
4949
}
5050

5151
TEST(workspaces, tile_vecElemMul_Tail1) {
@@ -83,7 +83,7 @@ TEST(workspaces, tile_vecElemMul_Tail1) {
8383
expected.compile();
8484
expected.assemble();
8585
expected.compute();
86-
ASSERT_TENSOR_EQ(A, expected);
86+
ASSERT_TENSOR_EQ(expected, A);
8787
}
8888

8989
TEST(workspaces, tile_vecElemMul_Tail2) {
@@ -121,7 +121,7 @@ TEST(workspaces, tile_vecElemMul_Tail2) {
121121
expected.compile();
122122
expected.assemble();
123123
expected.compute();
124-
ASSERT_TENSOR_EQ(A, expected);
124+
ASSERT_TENSOR_EQ(expected, A);
125125

126126
// ir::IRPrinter irp = ir::IRPrinter(cout);
127127
//
@@ -171,7 +171,7 @@ TEST(workspaces, tile_denseMatMul) {
171171
expected.compile();
172172
expected.assemble();
173173
expected.compute();
174-
ASSERT_TENSOR_EQ(A, expected);
174+
ASSERT_TENSOR_EQ(expected, A);
175175

176176
// ir::IRPrinter irp = ir::IRPrinter(cout);
177177
//
@@ -218,7 +218,7 @@ TEST(workspaces, precompute2D_add) {
218218
expected.compile();
219219
expected.assemble();
220220
expected.compute();
221-
ASSERT_TENSOR_EQ(A, expected);
221+
ASSERT_TENSOR_EQ(expected, A);
222222

223223
}
224224

@@ -263,7 +263,7 @@ TEST(workspaces, precompute4D_add) {
263263
expected.compile();
264264
expected.assemble();
265265
expected.compute();
266-
ASSERT_TENSOR_EQ(A, expected);
266+
ASSERT_TENSOR_EQ(expected, A);
267267
}
268268

269269
TEST(workspaces, precompute4D_multireduce) {
@@ -305,7 +305,7 @@ TEST(workspaces, precompute4D_multireduce) {
305305
expected.compile();
306306
expected.assemble();
307307
expected.compute();
308-
ASSERT_TENSOR_EQ(A, expected);
308+
ASSERT_TENSOR_EQ(expected, A);
309309
}
310310

311311
TEST(workspaces, precompute3D_TspV) {
@@ -344,7 +344,7 @@ TEST(workspaces, precompute3D_TspV) {
344344
expected.compile();
345345
expected.assemble();
346346
expected.compute();
347-
ASSERT_TENSOR_EQ(A, expected);
347+
ASSERT_TENSOR_EQ(expected, A);
348348

349349
}
350350

@@ -388,7 +388,7 @@ TEST(workspaces, precompute3D_multipleWS) {
388388
expected.compile();
389389
expected.assemble();
390390
expected.compute();
391-
ASSERT_TENSOR_EQ(A, expected);
391+
ASSERT_TENSOR_EQ(expected, A);
392392

393393
}
394394

@@ -431,6 +431,123 @@ TEST(workspaces, precompute3D_renamedIVars_TspV) {
431431
expected.compile();
432432
expected.assemble();
433433
expected.compute();
434-
ASSERT_TENSOR_EQ(A, expected);
434+
ASSERT_TENSOR_EQ(expected, A);
435435

436436
}
437+
438+
TEST(workspaces, DISABLED_tile_dotProduct_1) {
439+
// FIXME: Disabled because currently the precompute algorithm does not appropriately
440+
// optimize = from += when rewriting a statement for BOTH the producer and consumer
441+
// side of a where statement insertion.
442+
// Although always using += is CORRECT functionally, this fails the GPU tests since it
443+
// would result in scattering.
444+
int N = 1024;
445+
Tensor<double> A("A");
446+
Tensor<double> B("B", {N}, {Dense});
447+
Tensor<double> C("C", {N}, {Dense});
448+
449+
for (int i = 0; i < N; i++) {
450+
B.insert({i}, (double) i);
451+
C.insert({i}, (double) i);
452+
}
453+
454+
B.pack();
455+
C.pack();
456+
457+
IndexVar i("i");
458+
IndexVar i_bounded("i_bounded");
459+
IndexVar i0("i0"), i1("i1");
460+
IndexExpr BExpr = B(i);
461+
IndexExpr CExpr = C(i);
462+
IndexExpr precomputedExpr = (BExpr) * (CExpr);
463+
A() = precomputedExpr;
464+
465+
IndexStmt stmt = A.getAssignment().concretize();
466+
TensorVar B_new("B_new", Type(Float64, {(size_t)N}), taco::dense);
467+
TensorVar C_new("C_new", Type(Float64, {(size_t)N}), taco::dense);
468+
TensorVar precomputed("precomputed", Type(Float64, {(size_t)N}), taco::dense);
469+
470+
stmt = stmt.bound(i, i_bounded, (size_t)N, BoundType::MaxExact)
471+
.split(i_bounded, i0, i1, 32);
472+
stmt = stmt.precompute(precomputedExpr, i1, i1, precomputed);
473+
474+
cout << stmt << endl;
475+
cout << endl;
476+
477+
stmt = stmt.precompute(BExpr, i1, i1, B_new)
478+
.precompute(CExpr, i1, i1, C_new);
479+
480+
481+
stmt = stmt.concretize();
482+
cout << stmt << endl;
483+
484+
A.compile(stmt);
485+
A.assemble();
486+
A.compute();
487+
488+
Tensor<double> expected("expected");
489+
expected() = B(i) * C(i);
490+
expected.compile();
491+
expected.assemble();
492+
expected.compute();
493+
ASSERT_TENSOR_EQ(expected, A);
494+
}
495+
496+
TEST(workspaces, DISABLED_tile_dotProduct_2) {
497+
// FIXME: This is also currently disabled since split(...) scheduling commands
498+
// only split on the FIRST INSTANCE of an indexVar (assumes only one).
499+
// This is wrong if the indexVar is not renamed across iw_vars since an indexVar can
500+
// then occur on BOTH the consumer and producer side and should be split across both.
501+
502+
int N = 1024;
503+
Tensor<double> A("A");
504+
Tensor<double> B("B", {N}, {Dense});
505+
Tensor<double> C("C", {N}, {Dense});
506+
507+
for (int i = 0; i < N; i++) {
508+
B.insert({i}, (double) i);
509+
C.insert({i}, (double) i);
510+
}
511+
512+
B.pack();
513+
C.pack();
514+
515+
IndexVar i("i");
516+
IndexVar i_bounded("i_bounded");
517+
IndexVar i0("i0"), i1("i1");
518+
IndexExpr BExpr = B(i);
519+
IndexExpr CExpr = C(i);
520+
IndexExpr precomputedExpr = (BExpr) * (CExpr);
521+
A() = precomputedExpr;
522+
523+
IndexStmt stmt = A.getAssignment().concretize();
524+
TensorVar B_new("B_new", Type(Float64, {(size_t)N}), taco::dense);
525+
TensorVar C_new("C_new", Type(Float64, {(size_t)N}), taco::dense);
526+
TensorVar precomputed("precomputed", Type(Float64, {(size_t)N}), taco::dense);
527+
528+
stmt = stmt.precompute(precomputedExpr, i, i, precomputed);
529+
530+
cout << stmt << endl;
531+
cout << endl;
532+
533+
stmt = stmt.precompute(BExpr, i, i, B_new)
534+
.precompute(CExpr, i, i, C_new);
535+
536+
stmt = stmt.bound(i, i_bounded, (size_t)N, BoundType::MaxExact)
537+
.split(i_bounded, i0, i1, 32);
538+
539+
stmt = stmt.concretize();
540+
cout << stmt << endl;
541+
542+
A.compile(stmt);
543+
A.assemble();
544+
A.compute();
545+
546+
Tensor<double> expected("expected");
547+
expected() = B(i) * C(i);
548+
expected.compile();
549+
expected.assemble();
550+
expected.compute();
551+
ASSERT_TENSOR_EQ(expected, A);
552+
}
553+

0 commit comments

Comments
 (0)