@@ -254,8 +254,10 @@ TEST(workspaces, precompute4D_add) {
254254
255255
256256 IndexStmt stmt = A.getAssignment ().concretize ();
257- TensorVar ws (" ws" , Type (Float64, {N, N, N, N}), Format{Dense, Dense, Dense, Dense});
258- stmt = stmt.precompute (precomputedExpr, {i, j, k, l}, {i, j, k, l}, ws);
257+ TensorVar ws1 (" ws1" , Type (Float64, {N, N, N, N}), Format{Dense, Dense, Dense, Dense});
258+ TensorVar ws2 (" ws2" , Type (Float64, {N, N, N, N}), Format{Dense, Dense, Dense, Dense});
259+ stmt = stmt.precompute (precomputedExpr, {i, j, k, l}, {i, j, k, l}, ws1)
260+ .precompute (ws1 (i, j, k, l) + D (i, j, k, l), {i, j, k, l}, {i, j, k ,l}, ws2);
259261
260262 A.compile (stmt.concretize ());
261263 A.assemble ();
@@ -269,6 +271,48 @@ TEST(workspaces, precompute4D_add) {
269271 ASSERT_TENSOR_EQ (A, expected);
270272}
271273
274+ TEST (workspaces, precompute4D_multireduce) {
275+ int N = 16 ;
276+ Tensor<double > A (" A" , {N, N}, Format{Dense, Dense});
277+ Tensor<double > B (" B" , {N, N, N, N}, Format{Dense, Dense, Dense, Dense});
278+ Tensor<double > C (" C" , {N, N, N}, Format{Dense, Dense, Dense});
279+ Tensor<double > D (" D" , {N, N}, Format{Dense, Dense});
280+
281+ for (int i = 0 ; i < N; i++) {
282+ for (int j = 0 ; j < N; j++) {
283+ for (int k = 0 ; k < N; k++) {
284+ for (int l = 0 ; l < N; l++) {
285+ B.insert ({i, j, k, l}, (double ) k*l);
286+ C.insert ({i, j, k}, (double ) j * k);
287+ D.insert ({i, j}, (double ) i+j);
288+ }
289+ }
290+ }
291+ }
292+
293+ IndexVar i (" i" ), j (" j" ), k (" k" ), l (" l" ), m (" m" ), n (" n" );
294+ IndexExpr precomputedExpr = B (i, j, k, l) * C (k, l, m);
295+ A (i, j) = precomputedExpr * D (m, n);
296+
297+
298+ IndexStmt stmt = A.getAssignment ().concretize ();
299+ TensorVar ws1 (" ws1" , Type (Float64, {N, N, N}), Format{Dense, Dense, Dense});
300+ TensorVar ws2 (" ws2" , Type (Float64, {N, N}), Format{Dense, Dense});
301+ stmt = stmt.precompute (precomputedExpr, {i, j, m}, {i, j, m}, ws1)
302+ .precompute (ws1 (i, j, m) * D (m, n), {i, j}, {i, j}, ws2);
303+
304+ A.compile (stmt.concretize ());
305+ A.assemble ();
306+ A.compute ();
307+
308+ Tensor<double > expected (" expected" , {N, N}, Format{Dense, Dense});
309+ expected (i, j) = B (i, j, k, l) * C (k, l, m) * D (m, n);
310+ expected.compile ();
311+ expected.assemble ();
312+ expected.compute ();
313+ ASSERT_TENSOR_EQ (A, expected);
314+ }
315+
272316TEST (workspaces, precompute3D_MspV) {
273317 int N = 16 ;
274318 Tensor<double > A (" A" , {N, N}, Format{Dense, Dense});
0 commit comments