Skip to content

Commit aa57ebc

Browse files
committed
Add in more multidim workspace tests
1 parent 54fba50 commit aa57ebc

File tree

1 file changed

+46
-2
lines changed

1 file changed

+46
-2
lines changed

test/tests-workspaces.cpp

Lines changed: 46 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -254,8 +254,10 @@ TEST(workspaces, precompute4D_add) {
254254

255255

256256
IndexStmt stmt = A.getAssignment().concretize();
257-
TensorVar ws("ws", Type(Float64, {N, N, N, N}), Format{Dense, Dense, Dense, Dense});
258-
stmt = stmt.precompute(precomputedExpr, {i, j, k, l}, {i, j, k, l}, ws);
257+
TensorVar ws1("ws1", Type(Float64, {N, N, N, N}), Format{Dense, Dense, Dense, Dense});
258+
TensorVar ws2("ws2", Type(Float64, {N, N, N, N}), Format{Dense, Dense, Dense, Dense});
259+
stmt = stmt.precompute(precomputedExpr, {i, j, k, l}, {i, j, k, l}, ws1)
260+
.precompute(ws1(i, j, k, l) + D(i, j, k, l), {i, j, k, l}, {i, j, k ,l}, ws2);
259261

260262
A.compile(stmt.concretize());
261263
A.assemble();
@@ -269,6 +271,48 @@ TEST(workspaces, precompute4D_add) {
269271
ASSERT_TENSOR_EQ(A, expected);
270272
}
271273

274+
TEST(workspaces, precompute4D_multireduce) {
275+
int N = 16;
276+
Tensor<double> A("A", {N, N}, Format{Dense, Dense});
277+
Tensor<double> B("B", {N, N, N, N}, Format{Dense, Dense, Dense, Dense});
278+
Tensor<double> C("C", {N, N, N}, Format{Dense, Dense, Dense});
279+
Tensor<double> D("D", {N, N}, Format{Dense, Dense});
280+
281+
for (int i = 0; i < N; i++) {
282+
for (int j = 0; j < N; j++) {
283+
for (int k = 0; k < N; k++) {
284+
for (int l = 0; l < N; l++) {
285+
B.insert({i, j, k, l}, (double) k*l);
286+
C.insert({i, j, k}, (double) j * k);
287+
D.insert({i, j}, (double) i+j);
288+
}
289+
}
290+
}
291+
}
292+
293+
IndexVar i("i"), j("j"), k("k"), l("l"), m("m"), n("n");
294+
IndexExpr precomputedExpr = B(i, j, k, l) * C(k, l, m);
295+
A(i, j) = precomputedExpr * D(m, n);
296+
297+
298+
IndexStmt stmt = A.getAssignment().concretize();
299+
TensorVar ws1("ws1", Type(Float64, {N, N, N}), Format{Dense, Dense, Dense});
300+
TensorVar ws2("ws2", Type(Float64, {N, N}), Format{Dense, Dense});
301+
stmt = stmt.precompute(precomputedExpr, {i, j, m}, {i, j, m}, ws1)
302+
.precompute(ws1(i, j, m) * D(m, n), {i, j}, {i, j}, ws2);
303+
304+
A.compile(stmt.concretize());
305+
A.assemble();
306+
A.compute();
307+
308+
Tensor<double> expected("expected", {N, N}, Format{Dense, Dense});
309+
expected(i, j) = B(i, j, k, l) * C(k, l, m) * D(m, n);
310+
expected.compile();
311+
expected.assemble();
312+
expected.compute();
313+
ASSERT_TENSOR_EQ(A, expected);
314+
}
315+
272316
TEST(workspaces, precompute3D_MspV) {
273317
int N = 16;
274318
Tensor<double> A("A", {N, N}, Format{Dense, Dense});

0 commit comments

Comments
 (0)