|
1 | 1 | // RUN: mlir-opt %s -convert-vector-to-arm-sme -split-input-file -allow-unregistered-dialect | FileCheck %s |
2 | 2 |
|
3 | 3 | //===----------------------------------------------------------------------===// |
4 | | -// vector.transfer_read (with in-flight transpose) |
| 4 | +// vector.transfer_read |
5 | 5 | //===----------------------------------------------------------------------===// |
6 | 6 |
|
7 | | -// CHECK-LABEL: @transfer_read_2d_transpose_i8 |
8 | | -// CHECK: arm_sme.tile_load {{.*}} layout<vertical> : memref<?x?xi8>, vector<[16]x[16]xi8> |
9 | | -func.func @transfer_read_2d_transpose_i8(%src : memref<?x?xi8>) { |
| 7 | +// CHECK-LABEL: @transfer_read_2d_i8 |
| 8 | +// CHECK: arm_sme.tile_load %{{.*}}[{{.*}}] : memref<?x?xi8>, vector<[16]x[16]xi8> |
| 9 | +func.func @transfer_read_2d_i8(%src : memref<?x?xi8>) { |
10 | 10 | %c0 = arith.constant 0 : index |
11 | 11 | %pad = arith.constant 0 : i8 |
12 | | - %0 = vector.transfer_read %src[%c0, %c0], %pad {permutation_map = affine_map<(d0, d1) -> (d1, d0)>, in_bounds = [true, true]} : memref<?x?xi8>, vector<[16]x[16]xi8> |
| 12 | + %0 = vector.transfer_read %src[%c0, %c0], %pad {in_bounds = [true, true]} : memref<?x?xi8>, vector<[16]x[16]xi8> |
13 | 13 | "prevent.dce"(%0) : (vector<[16]x[16]xi8>) -> () |
14 | 14 | return |
15 | 15 | } |
16 | 16 |
|
17 | 17 | // ----- |
18 | 18 |
|
19 | | -// CHECK-LABEL: @transfer_read_2d_transpose_i16 |
20 | | -// CHECK: arm_sme.tile_load {{.*}} layout<vertical> : memref<?x?xi16>, vector<[8]x[8]xi16> |
21 | | -func.func @transfer_read_2d_transpose_i16(%src : memref<?x?xi16>) { |
| 19 | +// CHECK-LABEL: @transfer_read_2d_i16 |
| 20 | +// CHECK: arm_sme.tile_load %{{.*}}[{{.*}}] : memref<?x?xi16>, vector<[8]x[8]xi16> |
| 21 | +func.func @transfer_read_2d_i16(%src : memref<?x?xi16>) { |
22 | 22 | %c0 = arith.constant 0 : index |
23 | 23 | %pad = arith.constant 0 : i16 |
24 | | - %0 = vector.transfer_read %src[%c0, %c0], %pad {permutation_map = affine_map<(d0, d1) -> (d1, d0)>, in_bounds = [true, true]} : memref<?x?xi16>, vector<[8]x[8]xi16> |
| 24 | + %0 = vector.transfer_read %src[%c0, %c0], %pad {in_bounds = [true, true]} : memref<?x?xi16>, vector<[8]x[8]xi16> |
25 | 25 | "prevent.dce"(%0) : (vector<[8]x[8]xi16>) -> () |
26 | 26 | return |
27 | 27 | } |
28 | 28 |
|
29 | 29 | // ----- |
30 | 30 |
|
31 | | -// CHECK-LABEL: @transfer_read_2d_transpose_i32 |
32 | | -// CHECK: arm_sme.tile_load {{.*}} layout<vertical> : memref<?x?xi32>, vector<[4]x[4]xi32> |
33 | | -func.func @transfer_read_2d_transpose_i32(%src : memref<?x?xi32>) { |
| 31 | +// CHECK-LABEL: @transfer_read_2d_i32 |
| 32 | +// CHECK: arm_sme.tile_load %{{.*}}[{{.*}}] : memref<?x?xi32>, vector<[4]x[4]xi32> |
| 33 | +func.func @transfer_read_2d_i32(%src : memref<?x?xi32>) { |
34 | 34 | %c0 = arith.constant 0 : index |
35 | 35 | %pad = arith.constant 0 : i32 |
36 | | - %0 = vector.transfer_read %src[%c0, %c0], %pad {permutation_map = affine_map<(d0, d1) -> (d1, d0)>, in_bounds = [true, true]} : memref<?x?xi32>, vector<[4]x[4]xi32> |
| 36 | + %0 = vector.transfer_read %src[%c0, %c0], %pad {in_bounds = [true, true]} : memref<?x?xi32>, vector<[4]x[4]xi32> |
37 | 37 | "prevent.dce"(%0) : (vector<[4]x[4]xi32>) -> () |
38 | 38 | return |
39 | 39 | } |
40 | 40 |
|
41 | 41 | // ----- |
42 | 42 |
|
43 | | -// CHECK-LABEL: @transfer_read_2d_transpose_i64 |
44 | | -// CHECK: arm_sme.tile_load {{.*}} layout<vertical> : memref<?x?xi64>, vector<[2]x[2]xi64> |
45 | | -func.func @transfer_read_2d_transpose_i64(%src : memref<?x?xi64>) { |
| 43 | +// CHECK-LABEL: @transfer_read_2d_i64 |
| 44 | +// CHECK: arm_sme.tile_load %{{.*}}[{{.*}}] : memref<?x?xi64>, vector<[2]x[2]xi64> |
| 45 | +func.func @transfer_read_2d_i64(%src : memref<?x?xi64>) { |
46 | 46 | %c0 = arith.constant 0 : index |
47 | 47 | %pad = arith.constant 0 : i64 |
48 | | - %0 = vector.transfer_read %src[%c0, %c0], %pad {permutation_map = affine_map<(d0, d1) -> (d1, d0)>, in_bounds = [true, true]} : memref<?x?xi64>, vector<[2]x[2]xi64> |
| 48 | + %0 = vector.transfer_read %src[%c0, %c0], %pad {in_bounds = [true, true]} : memref<?x?xi64>, vector<[2]x[2]xi64> |
49 | 49 | "prevent.dce"(%0) : (vector<[2]x[2]xi64>) -> () |
50 | 50 | return |
51 | 51 | } |
52 | 52 |
|
53 | 53 | // ----- |
54 | 54 |
|
55 | | -// CHECK-LABEL: @transfer_read_2d_transpose_i128 |
56 | | -// CHECK: arm_sme.tile_load {{.*}} layout<vertical> : memref<?x?xi128>, vector<[1]x[1]xi128> |
57 | | -func.func @transfer_read_2d_transpose_i128(%src : memref<?x?xi128>) { |
| 55 | +// CHECK-LABEL: @transfer_read_2d_i128 |
| 56 | +// CHECK: arm_sme.tile_load %{{.*}}[{{.*}}] : memref<?x?xi128>, vector<[1]x[1]xi128> |
| 57 | +func.func @transfer_read_2d_i128(%src : memref<?x?xi128>) { |
58 | 58 | %c0 = arith.constant 0 : index |
59 | 59 | %pad = arith.constant 0 : i128 |
60 | | - %0 = vector.transfer_read %src[%c0, %c0], %pad {permutation_map = affine_map<(d0, d1) -> (d1, d0)>, in_bounds = [true, true]} : memref<?x?xi128>, vector<[1]x[1]xi128> |
| 60 | + %0 = vector.transfer_read %src[%c0, %c0], %pad {in_bounds = [true, true]} : memref<?x?xi128>, vector<[1]x[1]xi128> |
61 | 61 | "prevent.dce"(%0) : (vector<[1]x[1]xi128>) -> () |
62 | 62 | return |
63 | 63 | } |
64 | 64 |
|
65 | 65 | // ----- |
66 | 66 |
|
67 | | -// CHECK-LABEL: @transfer_read_2d_transpose_f16 |
68 | | -// CHECK: arm_sme.tile_load {{.*}} layout<vertical> : memref<?x?xf16>, vector<[8]x[8]xf16> |
69 | | -func.func @transfer_read_2d_transpose_f16(%src : memref<?x?xf16>) { |
| 67 | +// CHECK-LABEL: @transfer_read_2d_f16 |
| 68 | +// CHECK: arm_sme.tile_load %{{.*}}[{{.*}}] : memref<?x?xf16>, vector<[8]x[8]xf16> |
| 69 | +func.func @transfer_read_2d_f16(%src : memref<?x?xf16>) { |
70 | 70 | %c0 = arith.constant 0 : index |
71 | 71 | %pad = arith.constant 0.0 : f16 |
72 | | - %0 = vector.transfer_read %src[%c0, %c0], %pad {permutation_map = affine_map<(d0, d1) -> (d1, d0)>, in_bounds = [true, true]} : memref<?x?xf16>, vector<[8]x[8]xf16> |
| 72 | + %0 = vector.transfer_read %src[%c0, %c0], %pad {in_bounds = [true, true]} : memref<?x?xf16>, vector<[8]x[8]xf16> |
73 | 73 | "prevent.dce"(%0) : (vector<[8]x[8]xf16>) -> () |
74 | 74 | return |
75 | 75 | } |
76 | 76 |
|
77 | 77 | // ----- |
78 | 78 |
|
79 | | -// CHECK-LABEL: @transfer_read_2d_transpose_bf16 |
80 | | -// CHECK: arm_sme.tile_load {{.*}} layout<vertical> : memref<?x?xbf16>, vector<[8]x[8]xbf16> |
81 | | -func.func @transfer_read_2d_transpose_bf16(%src : memref<?x?xbf16>) { |
| 79 | +// CHECK-LABEL: @transfer_read_2d_bf16 |
| 80 | +// CHECK: arm_sme.tile_load %{{.*}}[{{.*}}] : memref<?x?xbf16>, vector<[8]x[8]xbf16> |
| 81 | +func.func @transfer_read_2d_bf16(%src : memref<?x?xbf16>) { |
82 | 82 | %c0 = arith.constant 0 : index |
83 | 83 | %pad = arith.constant 0.0 : bf16 |
84 | | - %0 = vector.transfer_read %src[%c0, %c0], %pad {permutation_map = affine_map<(d0, d1) -> (d1, d0)>, in_bounds = [true, true]} : memref<?x?xbf16>, vector<[8]x[8]xbf16> |
| 84 | + %0 = vector.transfer_read %src[%c0, %c0], %pad {in_bounds = [true, true]} : memref<?x?xbf16>, vector<[8]x[8]xbf16> |
85 | 85 | "prevent.dce"(%0) : (vector<[8]x[8]xbf16>) -> () |
86 | 86 | return |
87 | 87 | } |
88 | 88 |
|
89 | 89 | // ----- |
90 | 90 |
|
91 | | -// CHECK-LABEL: @transfer_read_2d_transpose_f32 |
92 | | -// CHECK: arm_sme.tile_load {{.*}} layout<vertical> : memref<?x?xf32>, vector<[4]x[4]xf32> |
93 | | -func.func @transfer_read_2d_transpose_f32(%src : memref<?x?xf32>) { |
| 91 | +// CHECK-LABEL: @transfer_read_2d_f32 |
| 92 | +// CHECK: arm_sme.tile_load %{{.*}}[{{.*}}] : memref<?x?xf32>, vector<[4]x[4]xf32> |
| 93 | +func.func @transfer_read_2d_f32(%src : memref<?x?xf32>) { |
94 | 94 | %c0 = arith.constant 0 : index |
95 | 95 | %pad = arith.constant 0.0 : f32 |
96 | | - %0 = vector.transfer_read %src[%c0, %c0], %pad {permutation_map = affine_map<(d0, d1) -> (d1, d0)>, in_bounds = [true, true]} : memref<?x?xf32>, vector<[4]x[4]xf32> |
| 96 | + %0 = vector.transfer_read %src[%c0, %c0], %pad {in_bounds = [true, true]} : memref<?x?xf32>, vector<[4]x[4]xf32> |
97 | 97 | "prevent.dce"(%0) : (vector<[4]x[4]xf32>) -> () |
98 | 98 | return |
99 | 99 | } |
100 | 100 |
|
101 | 101 | // ----- |
102 | 102 |
|
103 | | -// CHECK-LABEL: @transfer_read_2d_transpose_f64 |
104 | | -// CHECK: arm_sme.tile_load {{.*}} layout<vertical> : memref<?x?xf64>, vector<[2]x[2]xf64> |
105 | | -func.func @transfer_read_2d_transpose_f64(%src : memref<?x?xf64>) { |
| 103 | +// CHECK-LABEL: @transfer_read_2d_f64 |
| 104 | +// CHECK: arm_sme.tile_load %{{.*}}[{{.*}}] : memref<?x?xf64>, vector<[2]x[2]xf64> |
| 105 | +func.func @transfer_read_2d_f64(%src : memref<?x?xf64>) { |
106 | 106 | %c0 = arith.constant 0 : index |
107 | 107 | %pad = arith.constant 0.0 : f64 |
108 | | - %0 = vector.transfer_read %src[%c0, %c0], %pad {permutation_map = affine_map<(d0, d1) -> (d1, d0)>, in_bounds = [true, true]} : memref<?x?xf64>, vector<[2]x[2]xf64> |
| 108 | + %0 = vector.transfer_read %src[%c0, %c0], %pad {in_bounds = [true, true]} : memref<?x?xf64>, vector<[2]x[2]xf64> |
109 | 109 | "prevent.dce"(%0) : (vector<[2]x[2]xf64>) -> () |
110 | 110 | return |
111 | 111 | } |
112 | 112 |
|
113 | 113 | // ----- |
114 | 114 |
|
115 | | -// CHECK-LABEL: @transfer_read_2d__bad_type |
116 | | -// CHECK-NOT: arm_sme.tile_load |
117 | | -// CHECK: vector.transfer_read |
118 | | -func.func @transfer_read_2d__bad_type(%src : memref<?x?xf64>) { |
| 115 | +// CHECK-LABEL: @transfer_read_2d_with_mask_i16 |
| 116 | +// CHECK: arm_sme.tile_load %{{.*}}[{{.*}}], {{.*}}, {{.*}} : memref<?x?xi16>, vector<[8]x[8]xi16> |
| 117 | +func.func @transfer_read_2d_with_mask_i16(%src : memref<?x?xi16>, %mask : vector<[8]x[8]xi1>) { |
119 | 118 | %c0 = arith.constant 0 : index |
120 | | - %pad = arith.constant 0.0 : f64 |
121 | | - %0 = vector.transfer_read %src[%c0, %c0], %pad {permutation_map = affine_map<(d0, d1) -> (d1, d0)>, in_bounds = [false, false]} : memref<?x?xf64>, vector<[4]x[4]xf64> |
122 | | - "prevent.dce"(%0) : (vector<[4]x[4]xf64>) -> () |
| 119 | + %pad = arith.constant 0 : i16 |
| 120 | + %0 = vector.transfer_read %src[%c0, %c0], %pad, %mask {in_bounds = [true, true]} : memref<?x?xi16>, vector<[8]x[8]xi16> |
| 121 | + "prevent.dce"(%0) : (vector<[8]x[8]xi16>) -> () |
123 | 122 | return |
124 | 123 | } |
125 | 124 |
|
126 | 125 | // ----- |
127 | 126 |
|
128 | | -// CHECK-LABEL: @transfer_read_2d__non_memref_type |
129 | | -// CHECK-NOT: arm_sme.tile_load |
130 | | -// CHECK: vector.transfer_read |
131 | | -func.func @transfer_read_2d__non_memref_type(%src : tensor<?x?xf64>) { |
| 127 | +/// vector.transfer_read with in-flight transpose |
| 128 | + |
| 129 | +// CHECK-LABEL: @transfer_read_2d_transpose_i8 |
| 130 | +// CHECK: arm_sme.tile_load {{.*}} layout<vertical> : memref<?x?xi8>, vector<[16]x[16]xi8> |
| 131 | +func.func @transfer_read_2d_transpose_i8(%src : memref<?x?xi8>) { |
132 | 132 | %c0 = arith.constant 0 : index |
133 | | - %pad = arith.constant 0.0 : f64 |
134 | | - %0 = vector.transfer_read %src[%c0, %c0], %pad {permutation_map = affine_map<(d0, d1) -> (d1, d0)>, in_bounds = [true, true]} : tensor<?x?xf64>, vector<[2]x[2]xf64> |
135 | | - "prevent.dce"(%0) : (vector<[2]x[2]xf64>) -> () |
| 133 | + %pad = arith.constant 0 : i8 |
| 134 | + %0 = vector.transfer_read %src[%c0, %c0], %pad {permutation_map = affine_map<(d0, d1) -> (d1, d0)>, in_bounds = [true, true]} : memref<?x?xi8>, vector<[16]x[16]xi8> |
| 135 | + "prevent.dce"(%0) : (vector<[16]x[16]xi8>) -> () |
136 | 136 | return |
137 | 137 | } |
138 | 138 |
|
139 | 139 | // ----- |
140 | 140 |
|
141 | | -// CHECK-LABEL: @transfer_read_2d__bad_transfer_rank |
| 141 | +// CHECK-LABEL: @transfer_read_2d_transpose_with_mask_f32 |
| 142 | +// CHECK: arm_sme.tile_load {{.*}} layout<vertical> : memref<?x?xf32>, vector<[4]x[4]xf32> |
| 143 | +func.func @transfer_read_2d_transpose_with_mask_f32(%src : memref<?x?xf32>, %mask : vector<[4]x[4]xi1>) { |
| 144 | + %c0 = arith.constant 0 : index |
| 145 | + %pad = arith.constant 0.0 : f32 |
| 146 | + %0 = vector.transfer_read %src[%c0, %c0], %pad, %mask {permutation_map = affine_map<(d0, d1) -> (d1, d0)>, in_bounds = [true, true]} : memref<?x?xf32>, vector<[4]x[4]xf32> |
| 147 | + "prevent.dce"(%0) : (vector<[4]x[4]xf32>) -> () |
| 148 | + return |
| 149 | +} |
| 150 | + |
| 151 | +// ----- |
| 152 | + |
| 153 | +// CHECK-LABEL: @transfer_read_2d__bad_type |
142 | 154 | // CHECK-NOT: arm_sme.tile_load |
143 | 155 | // CHECK: vector.transfer_read |
144 | | -func.func @transfer_read_2d__bad_transfer_rank(%src : memref<?x?xf64>) { |
| 156 | +func.func @transfer_read_2d__bad_type(%src : memref<?x?xf64>) { |
145 | 157 | %c0 = arith.constant 0 : index |
146 | 158 | %pad = arith.constant 0.0 : f64 |
147 | | - %0 = vector.transfer_read %src[%c0, %c0], %pad {permutation_map = affine_map<(d0, d1) -> (d0)>, in_bounds = [true]} : memref<?x?xf64>, vector<[2]xf64> |
148 | | - "prevent.dce"(%0) : (vector<[2]xf64>) -> () |
| 159 | + %0 = vector.transfer_read %src[%c0, %c0], %pad {permutation_map = affine_map<(d0, d1) -> (d1, d0)>, in_bounds = [false, false]} : memref<?x?xf64>, vector<[4]x[4]xf64> |
| 160 | + "prevent.dce"(%0) : (vector<[4]x[4]xf64>) -> () |
149 | 161 | return |
150 | 162 | } |
151 | 163 |
|
152 | 164 | // ----- |
153 | 165 |
|
154 | | -// CHECK-LABEL: @transfer_read_2d__unsupported_mask |
| 166 | +// CHECK-LABEL: @transfer_read_2d__non_memref_type |
155 | 167 | // CHECK-NOT: arm_sme.tile_load |
156 | 168 | // CHECK: vector.transfer_read |
157 | | -func.func @transfer_read_2d__unsupported_mask(%src : memref<?x?xf64>, %mask : vector<[2]x[2]xi1>) { |
| 169 | +func.func @transfer_read_2d__non_memref_type(%src : tensor<?x?xf64>) { |
158 | 170 | %c0 = arith.constant 0 : index |
159 | 171 | %pad = arith.constant 0.0 : f64 |
160 | | - %0 = vector.transfer_read %src[%c0, %c0], %pad, %mask {permutation_map = affine_map<(d0, d1) -> (d1, d0)>, in_bounds = [true, true]} : memref<?x?xf64>, vector<[2]x[2]xf64> |
| 172 | + %0 = vector.transfer_read %src[%c0, %c0], %pad {permutation_map = affine_map<(d0, d1) -> (d1, d0)>, in_bounds = [true, true]} : tensor<?x?xf64>, vector<[2]x[2]xf64> |
161 | 173 | "prevent.dce"(%0) : (vector<[2]x[2]xf64>) -> () |
162 | 174 | return |
163 | 175 | } |
164 | 176 |
|
165 | 177 | // ----- |
166 | 178 |
|
167 | | -/// transfer_read with identity map should be lowered to vector.load by |
168 | | -/// TransferReadToVectorLoadLowering and then arm_sme.tile_load by |
169 | | -/// VectorLoadToArmSMELowering. |
170 | | - |
171 | | -// CHECK-LABEL: @transfer_read_2d__non_permuting_map |
| 179 | +// CHECK-LABEL: @transfer_read_2d__bad_transfer_rank |
172 | 180 | // CHECK-NOT: arm_sme.tile_load |
173 | 181 | // CHECK: vector.transfer_read |
174 | | -func.func @transfer_read_2d__non_permuting_map(%src : memref<?x?xf64>) { |
| 182 | +func.func @transfer_read_2d__bad_transfer_rank(%src : memref<?x?xf64>) { |
175 | 183 | %c0 = arith.constant 0 : index |
176 | 184 | %pad = arith.constant 0.0 : f64 |
177 | | - %0 = vector.transfer_read %src[%c0, %c0], %pad {permutation_map = affine_map<(d0, d1) -> (d0, d1)>, in_bounds = [true, true]} : memref<?x?xf64>, vector<[2]x[2]xf64> |
178 | | - "prevent.dce"(%0) : (vector<[2]x[2]xf64>) -> () |
| 185 | + %0 = vector.transfer_read %src[%c0, %c0], %pad {permutation_map = affine_map<(d0, d1) -> (d0)>, in_bounds = [true]} : memref<?x?xf64>, vector<[2]xf64> |
| 186 | + "prevent.dce"(%0) : (vector<[2]xf64>) -> () |
179 | 187 | return |
180 | 188 | } |
181 | 189 |
|
|
0 commit comments