@@ -380,15 +380,16 @@ class MergerTest3T1LD : public MergerTestBase {
 ///
 /// Tests with both undef and dense input.
 ///
-class MergerTest3T1LU : public MergerTestBase {
+
+class MergerTest4T1LU : public MergerTestBase {
 protected:
-  // Our three tensors (two inputs, one output).
+  // Our four tensors (three inputs, one output).
-  const unsigned t0 = 0, t1 = 1, t2 = 2;
+  const unsigned t0 = 0, t1 = 1, t2 = 2, t3 = 3;
 
   // Our single loop.
   const unsigned l0 = 0;
 
-  MergerTest3T1LU() : MergerTestBase(3, 1) {
+  MergerTest4T1LU() : MergerTestBase(4, 1) {
     // Tensor 0: undef input vector.
     merger.addExp(Kind::kTensor, t0, -1u);
     merger.setDimLevelFormat(t0, l0, DimLevelFormat(DimLvlType::kUndef));
@@ -397,43 +398,110 @@ class MergerTest3T1LU : public MergerTestBase {
     merger.addExp(Kind::kTensor, t1, -1u);
     merger.setDimLevelFormat(t1, l0, DimLevelFormat(DimLvlType::kDense));
 
-    // Tensor 2: dense output vector.
+    // Tensor 2: undef input vector.
     merger.addExp(Kind::kTensor, t2, -1u);
-    merger.setDimLevelFormat(t2, l0, DimLevelFormat(DimLvlType::kDense));
+    merger.setDimLevelFormat(t2, l0, DimLevelFormat(DimLvlType::kUndef));
+
+    // Tensor 3: dense output vector.
+    merger.addExp(Kind::kTensor, t3, -1u);
+    merger.setDimLevelFormat(t3, l0, DimLevelFormat(DimLvlType::kDense));
+  }
+};
+
+///
+/// Tests with an operation on a sparse output.
+///
+
+class MergerTest3T1L_SO : public MergerTestBase {
+protected:
+  // Our three tensors (two inputs, one output), plus the synthetic tensor id.
+  const unsigned t0 = 0, t1 = 1, t2 = 2, t3 = 3;
+
+  // Our single loop.
+  const unsigned l0 = 0;
+
+  MergerTest3T1L_SO() : MergerTestBase(3, 1) {
+    merger.setHasSparseOut(true);
+
+    // Tensor 0: undef input vector.
+    merger.addExp(Kind::kTensor, t0, -1u);
+    merger.setDimLevelFormat(t0, l0, DimLevelFormat(DimLvlType::kUndef));
+
+    // Tensor 1: undef input vector.
+    merger.addExp(Kind::kTensor, t1, -1u);
+    merger.setDimLevelFormat(t1, l0, DimLevelFormat(DimLvlType::kUndef));
+
+    // Tensor 2: sparse output vector.
+    merger.addExp(Kind::kTensor, t2, -1u);
+    merger.setDimLevelFormat(t2, l0, DimLevelFormat(DimLvlType::kCompressed));
   }
 };
+
 } // namespace
 
-/// Vector multiplication (conjunction) of 2 vectors, i.e.;
-///   a(i) = b(i) * c(i)
+/// Vector multiplication (conjunction) of 3 vectors, i.e.;
+///   a(i) = b(i) * c(i) * d(i)
 /// which should form the single lattice point
 /// {
-///   lat( i_00_U i_01_D / (tensor_0 * tensor_1) )
+///   lat( i_00_U i_01_D i_02_U / (tensor_0 * tensor_1 * tensor_2) )
 /// }
-/// after optimization, the dense dimesion should be kept, despite it appears
-/// after the undef dimension
+/// after optimization, the dense dimension should be kept, even though it
+/// appears in the middle
 /// {
-///   lat( i_01_D / (tensor_0 * tensor_1) )
+///   lat( i_01_D / (tensor_0 * tensor_1 * tensor_2) )
 /// }
-#define IMPL_MERGER_TEST_CONJ(OP) \
-  TEST_F(MergerTest3T1LU, vector_##OP) { \
-    auto e = OP##Expr(t0, t1); \
+#define IMPL_MERGER_TEST_CONJ_CONJ_UNDEF(CONJ1, CONJ2) \
+  TEST_F(MergerTest4T1LU, vector_##CONJ1##_##CONJ2) { \
+    auto em = CONJ1##Expr(t0, t1); \
+    auto e = CONJ2##Expr(em, t2); \
     auto p0 = tensorPattern(t0); \
     auto p1 = tensorPattern(t1); \
+    auto p2 = tensorPattern(t2); \
     auto s = merger.buildLattices(e, l0); \
-    \
     expectNumLatPoints(s, 1); \
-    expectLatPoint(s, lat(0), OP##Pattern(p0, p1), \
-                   loopsToBits({{l0, t0}, {l0, t1}})); \
-    \
+    expectLatPoint(s, lat(0), CONJ2##Pattern(CONJ1##Pattern(p0, p1), p2), \
+                   loopsToBits({{l0, t0}, {l0, t1}, {l0, t2}})); \
     s = merger.optimizeSet(s); \
     expectNumLatPoints(s, 1); \
-    expectLatPoint(s, lat(0), OP##Pattern(p0, p1), loopsToBits({{l0, t1}}), \
-                   true); \
+    expectLatPoint(s, lat(0), CONJ2##Pattern(CONJ1##Pattern(p0, p1), p2), \
+                   loopsToBits({{l0, t1}}), true); \
   }
-FOREVERY_COMMON_CONJ_BINOP(IMPL_MERGER_TEST_CONJ)
 
-#undef IMPL_MERGER_TEST_CONJ
+FOREVERY_PAIR_OF_COMMON_CONJ_CONJ_BINOP(IMPL_MERGER_TEST_CONJ_CONJ_UNDEF)
+
+#undef IMPL_MERGER_TEST_CONJ_CONJ_UNDEF
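
For illustration only (this snippet is not part of the patch): assuming the pair list expanded by FOREVERY_PAIR_OF_COMMON_CONJ_CONJ_BINOP includes (mulf, mulf), where mulfExpr and mulfPattern are the per-operation helpers the fixture already provides for the single-conjunction tests, IMPL_MERGER_TEST_CONJ_CONJ_UNDEF(mulf, mulf) would expand to roughly:

    TEST_F(MergerTest4T1LU, vector_mulf_mulf) {
      // Build e = (t0 * t1) * t2 as a nested conjunction.
      auto em = mulfExpr(t0, t1);
      auto e = mulfExpr(em, t2);
      auto p0 = tensorPattern(t0);
      auto p1 = tensorPattern(t1);
      auto p2 = tensorPattern(t2);
      auto s = merger.buildLattices(e, l0);
      // Before optimization: one lattice point carrying the loop-l0 bits of
      // all three input tensors.
      expectNumLatPoints(s, 1);
      expectLatPoint(s, lat(0), mulfPattern(mulfPattern(p0, p1), p2),
                     loopsToBits({{l0, t0}, {l0, t1}, {l0, t2}}));
      // After optimization: only the dense tensor t1's bit survives.
      s = merger.optimizeSet(s);
      expectNumLatPoints(s, 1);
      expectLatPoint(s, lat(0), mulfPattern(mulfPattern(p0, p1), p2),
                     loopsToBits({{l0, t1}}), true);
    }
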
+
+/// Vector multiplication (conjunction) of 2 vectors and the sparse output, i.e.;
+///   o(i) = b(i) * c(i) * o(i)
+/// which should form the single lattice point (note how a synthetic tensor
+/// i_03_U is created for the sparse output)
+/// {
+///   lat( i_00_U i_01_U i_03_U / (tensor_0 * tensor_1 * output_tensor_2) )
+/// }
+/// after optimization, the synthetic tensor should be preserved.
+/// {
+///   lat( i_03_U / (tensor_0 * tensor_1 * output_tensor_2) )
+/// }
+#define IMPL_MERGER_TEST_CONJ_CONJ_SPARSE_OUT(CONJ1, CONJ2) \
+  TEST_F(MergerTest3T1L_SO, vector_##CONJ1##_##CONJ2) { \
+    auto em = CONJ1##Expr(t0, t1); \
+    auto e = CONJ2##Expr(em, t2); \
+    auto p0 = tensorPattern(t0); \
+    auto p1 = tensorPattern(t1); \
+    auto p2 = tensorPattern(t2); \
+    auto s = merger.buildLattices(e, l0); \
+    expectNumLatPoints(s, 1); \
+    expectLatPoint(s, lat(0), CONJ2##Pattern(CONJ1##Pattern(p0, p1), p2), \
+                   loopsToBits({{l0, t0}, {l0, t1}, {l0, t3}})); \
+    s = merger.optimizeSet(s); \
+    expectNumLatPoints(s, 1); \
+    expectLatPoint(s, lat(0), CONJ2##Pattern(CONJ1##Pattern(p0, p1), p2), \
+                   loopsToBits({{l0, t3}}), true); \
+  }
+
+FOREVERY_PAIR_OF_COMMON_CONJ_CONJ_BINOP(IMPL_MERGER_TEST_CONJ_CONJ_SPARSE_OUT)
+
+#undef IMPL_MERGER_TEST_CONJ_CONJ_SPARSE_OUT
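
Both test groups are instantiated through FOREVERY_PAIR_OF_COMMON_CONJ_CONJ_BINOP, whose definition lives earlier in MergerTest.cpp and is not shown in this hunk. Note that the sparse-output group expects the bit {l0, t3}: with three declared tensors the synthetic tensor takes the next id, 3, which is why t3 is declared in the fixture. As a purely hypothetical sketch of the pair macro's shape (the real operation pairs may differ), it enumerates conjunction pairs and applies the given TEST macro to each:

    // Hypothetical sketch; the actual definition and its operation pairs are
    // in MergerTest.cpp, outside this excerpt.
    #define FOREVERY_PAIR_OF_COMMON_CONJ_CONJ_BINOP(TEST) \
      TEST(mulf, mulf) \
      TEST(mulf, muli) \
      TEST(muli, muli)
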
 
 /// Vector addition (disjunction) of 2 vectors. i.e.;
 ///   a(i) = b(i) + c(i)