@@ -442,98 +442,124 @@ func @test_transpose(%arg0: tensor<1x2x3xi32>) -> () {
442442// -----
443443
444444// CHECK: #[[$MAP0:.*]] = affine_map<(d0, d1) -> (d0, d1)>
445- // CHECK: #[[$MAP1:.*]] = affine_map<(d0, d1) -> (d1)>
446- // CHECK: #[[$MAP2:.*]] = affine_map<(d0, d1) -> (d0)>
445+ // CHECK: #[[$MAP1:.*]] = affine_map<(d0, d1) -> (0, d1)>
446+ // CHECK: #[[$MAP2:.*]] = affine_map<(d0, d1) -> (d0, 0 )>
447447
448448// CHECK-LABEL: @reduce_float
449449// CHECK-SAME: [[ARG0:%.+]]: tensor<5x4xf32>
450450func @reduce_float (%arg0: tensor <5 x4 xf32 >) -> () {
451- // CHECK: [[INIT:%.+]] = linalg.init_tensor [4]
451+ // CHECK: [[INIT:%.+]] = linalg.init_tensor [1, 4]
452452 // CHECK: [[CST0:%.+]] = constant 0.0
453453 // CHECK: [[FILL:%.+]] = linalg.fill([[INIT]], [[CST0]])
454- // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["reduction", "parallel"]} ins([[ARG0]] : tensor<5x4xf32>) outs([[FILL]] : tensor<4xf32 >)
454+ // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["reduction", "parallel"]} ins([[ARG0]] : tensor<5x4xf32>) outs([[FILL]] : tensor<1x4xf32 >)
455455 // CHECK: ^bb0(%arg1: f32, %arg2: f32)
456456 // CHECK: [[RES:%.+]] = addf %arg1, %arg2 : f32
457457 // CHECK: linalg.yield [[RES]] : f32
458- %0 = " tosa.reduce_sum" (%arg0 ) {axis = 0 : i64 } : (tensor <5 x4 xf32 >) -> tensor <4 x f32 >
458+ %0 = " tosa.reduce_sum" (%arg0 ) {axis = 0 : i64 } : (tensor <5 x4 xf32 >) -> tensor <1 x 4 x f32 >
459459
460- // CHECK: [[INIT:%.+]] = linalg.init_tensor [5]
460+ // CHECK: [[INIT:%.+]] = linalg.init_tensor [5, 1 ]
461461 // CHECK: [[CST0:%.+]] = constant 0.0
462462 // CHECK: [[FILL:%.+]] = linalg.fill([[INIT]], [[CST0]])
463- // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP2]]], iterator_types = ["parallel", "reduction"]} ins([[ARG0]] : tensor<5x4xf32>) outs([[FILL]] : tensor<5xf32 >)
463+ // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP2]]], iterator_types = ["parallel", "reduction"]} ins([[ARG0]] : tensor<5x4xf32>) outs([[FILL]] : tensor<5x1xf32 >)
464464 // CHECK: ^bb0(%arg1: f32, %arg2: f32)
465465 // CHECK: [[RES:%.+]] = addf %arg1, %arg2 : f32
466466 // CHECK: linalg.yield [[RES]] : f32
467- %1 = " tosa.reduce_sum" (%arg0 ) {axis = 1 : i64 } : (tensor <5 x4 xf32 >) -> tensor <5 x f32 >
467+ %1 = " tosa.reduce_sum" (%arg0 ) {axis = 1 : i64 } : (tensor <5 x4 xf32 >) -> tensor <5 x 1 x f32 >
468468
469469 // CHECK: constant 1.0
470470 // CHECK: linalg.fill
471471 // CHECK: linalg.generic
472472 // CHECK: mulf
473- %2 = " tosa.reduce_prod" (%arg0 ) {axis = 0 : i64 } : (tensor <5 x4 xf32 >) -> tensor <4 x f32 >
473+ %2 = " tosa.reduce_prod" (%arg0 ) {axis = 0 : i64 } : (tensor <5 x4 xf32 >) -> tensor <1 x 4 x f32 >
474474
475475 // CHECK: constant 3.40282347E+38 : f32
476476 // CHECK: linalg.fill
477477 // CHECK: linalg.generic
478478 // CHECK: cmpf olt
479479 // CHECK: select
480- %3 = " tosa.reduce_min" (%arg0 ) {axis = 0 : i64 } : (tensor <5 x4 xf32 >) -> tensor <4 x f32 >
480+ %3 = " tosa.reduce_min" (%arg0 ) {axis = 0 : i64 } : (tensor <5 x4 xf32 >) -> tensor <1 x 4 x f32 >
481481
482482 // CHECK: constant -3.40282347E+38 : f32
483483 // CHECK: linalg.fill
484484 // CHECK: linalg.generic
485485 // CHECK: cmpf ogt
486486 // CHECK: select
487- %4 = " tosa.reduce_max" (%arg0 ) {axis = 0 : i64 } : (tensor <5 x4 xf32 >) -> tensor <4 x f32 >
487+ %4 = " tosa.reduce_max" (%arg0 ) {axis = 0 : i64 } : (tensor <5 x4 xf32 >) -> tensor <1 x 4 x f32 >
488488 return
489489}
490490
491491// -----
492492
493493// CHECK: #[[$MAP0:.*]] = affine_map<(d0, d1) -> (d0, d1)>
494- // CHECK: #[[$MAP1:.*]] = affine_map<(d0, d1) -> (d1)>
495- // CHECK: #[[$MAP2:.*]] = affine_map<(d0, d1) -> (d0)>
494+ // CHECK: #[[$MAP1:.*]] = affine_map<(d0, d1) -> (0, d1)>
495+ // CHECK: #[[$MAP2:.*]] = affine_map<(d0, d1) -> (d0, 0 )>
496496
497497// CHECK-LABEL: @reduce_int
498498// CHECK-SAME: [[ARG0:%.+]]: tensor<5x4xi32>
499499func @reduce_int (%arg0: tensor <5 x4 xi32 >) -> () {
500- // CHECK: [[INIT:%.+]] = linalg.init_tensor [4]
500+ // CHECK: [[INIT:%.+]] = linalg.init_tensor [1, 4]
501501 // CHECK: [[CST0:%.+]] = constant 0
502502 // CHECK: [[FILL:%.+]] = linalg.fill([[INIT]], [[CST0]])
503- // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["reduction", "parallel"]} ins([[ARG0]] : tensor<5x4xi32>) outs([[FILL]] : tensor<4xi32 >)
503+ // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["reduction", "parallel"]} ins([[ARG0]] : tensor<5x4xi32>) outs([[FILL]] : tensor<1x4xi32 >)
504504 // CHECK: ^bb0(%arg1: i32, %arg2: i32)
505505 // CHECK: [[RES:%.+]] = addi %arg1, %arg2 : i32
506506 // CHECK: linalg.yield [[RES]] : i32
507- %0 = " tosa.reduce_sum" (%arg0 ) {axis = 0 : i64 } : (tensor <5 x4 xi32 >) -> tensor <4 x i32 >
507+ %0 = " tosa.reduce_sum" (%arg0 ) {axis = 0 : i64 } : (tensor <5 x4 xi32 >) -> tensor <1 x 4 x i32 >
508508
509- // CHECK: [[INIT:%.+]] = linalg.init_tensor [5]
509+ // CHECK: [[INIT:%.+]] = linalg.init_tensor [5, 1 ]
510510 // CHECK: [[CST0:%.+]] = constant 0
511511 // CHECK: [[FILL:%.+]] = linalg.fill([[INIT]], [[CST0]])
512- // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP2]]], iterator_types = ["parallel", "reduction"]} ins([[ARG0]] : tensor<5x4xi32>) outs([[FILL]] : tensor<5xi32 >)
512+ // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP2]]], iterator_types = ["parallel", "reduction"]} ins([[ARG0]] : tensor<5x4xi32>) outs([[FILL]] : tensor<5x1xi32 >)
513513 // CHECK: ^bb0(%arg1: i32, %arg2: i32)
514514 // CHECK: [[RES:%.+]] = addi %arg1, %arg2 : i32
515515 // CHECK: linalg.yield [[RES]] : i32
516- %1 = " tosa.reduce_sum" (%arg0 ) {axis = 1 : i64 } : (tensor <5 x4 xi32 >) -> tensor <5 x i32 >
516+ %1 = " tosa.reduce_sum" (%arg0 ) {axis = 1 : i64 } : (tensor <5 x4 xi32 >) -> tensor <5 x 1 x i32 >
517517
518518 // CHECK: constant 1
519519 // CHECK: linalg.fill
520520 // CHECK: linalg.generic
521521 // CHECK: muli
522- %2 = " tosa.reduce_prod" (%arg0 ) {axis = 0 : i64 } : (tensor <5 x4 xi32 >) -> tensor <4 x i32 >
522+ %2 = " tosa.reduce_prod" (%arg0 ) {axis = 0 : i64 } : (tensor <5 x4 xi32 >) -> tensor <1 x 4 x i32 >
523523
524524 // CHECK: constant 2147483647 : i32
525525 // CHECK: linalg.fill
526526 // CHECK: linalg.generic
527527 // CHECK: cmpi slt
528528 // CHECK: select
529- %3 = " tosa.reduce_min" (%arg0 ) {axis = 0 : i64 } : (tensor <5 x4 xi32 >) -> tensor <4 x i32 >
529+ %3 = " tosa.reduce_min" (%arg0 ) {axis = 0 : i64 } : (tensor <5 x4 xi32 >) -> tensor <1 x 4 x i32 >
530530
531531 // CHECK: constant -2147483648 : i32
532532 // CHECK: linalg.fill
533533 // CHECK: linalg.generic
534534 // CHECK: cmpi sgt
535535 // CHECK: select
536- %4 = " tosa.reduce_max" (%arg0 ) {axis = 0 : i64 } : (tensor <5 x4 xi32 >) -> tensor <4 xi32 >
536+ %4 = " tosa.reduce_max" (%arg0 ) {axis = 0 : i64 } : (tensor <5 x4 xi32 >) -> tensor <1 x4 xi32 >
537+ return
538+ }
539+
540+ // -----
541+
542+ // CHECK: #[[$MAP0:.*]] = affine_map<(d0, d1) -> (d0, d1)>
543+ // CHECK: #[[$MAP1:.*]] = affine_map<(d0, d1) -> (0, d1)>
544+
545+ // CHECK-LABEL: @reduce_bool
546+ // CHECK-SAME: [[ARG0:%.+]]: tensor<5x4xi1>
547+ func @reduce_bool (%arg0: tensor <5 x4 xi1 >) -> () {
548+ // CHECK: [[INIT:%.+]] = linalg.init_tensor [1, 4]
549+ // CHECK: [[CST0:%.+]] = constant true
550+ // CHECK: [[FILL:%.+]] = linalg.fill([[INIT]], [[CST0]])
551+ // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["reduction", "parallel"]} ins([[ARG0]] : tensor<5x4xi1>) outs([[FILL]] : tensor<1x4xi1>)
552+ // CHECK: ^bb0(%arg1: i1, %arg2: i1)
553+ // CHECK: [[RES:%.+]] = and %arg1, %arg2 : i1
554+ // CHECK: linalg.yield [[RES]] : i1
555+ %0 = " tosa.reduce_all" (%arg0 ) {axis = 0 : i64 } : (tensor <5 x4 xi1 >) -> tensor <1 x4 xi1 >
556+
557+ // CHECK: constant false
558+ // CHECK: linalg.fill
559+ // CHECK: linalg.generic
560+ // CHECK: or
561+ %1 = " tosa.reduce_any" (%arg0 ) {axis = 0 : i64 } : (tensor <5 x4 xi1 >) -> tensor <1 x4 xi1 >
562+
537563 return
538564}
539565
0 commit comments