@@ -97,7 +97,7 @@ class Test3 : public TestBase {
9797 static constexpr size_t mat_m = 16 ;
9898 static constexpr size_t mat_n = 64 ;
9999 static constexpr size_t mat_k = 32 ;
100- static constexpr size_t wg_m = 8 ;
100+ static constexpr size_t wg_m = 16 ;
101101 static constexpr size_t wg_n = 64 ;
102102 static constexpr size_t sg_m = 1 ;
103103 static constexpr size_t sg_n = 64 ;
@@ -205,7 +205,7 @@ class Test8 : public TestBase {
205205 static constexpr uint32_t global_kslicing = 2 ;
206206 static constexpr uint32_t local_kslicing = 1 ;
207207 static constexpr mem_layout layout_a = mem_layout::row_major;
208- static constexpr mem_layout layout_b = mem_layout::col_major ;
208+ static constexpr mem_layout layout_b = mem_layout::row_major ;
209209 using data_type_a = float ;
210210 using data_type_b = float ;
211211 using data_type_c = float ;
@@ -227,7 +227,6 @@ class Test9 : public TestBase {
227227 static constexpr uint32_t local_kslicing = 1 ;
228228 static constexpr mem_layout layout_a = mem_layout::row_major;
229229 static constexpr mem_layout layout_b = mem_layout::row_major;
230- static constexpr mma_engine engine = mma_engine::xmx;
231230 using data_type_a = float ;
232231 using data_type_b = float ;
233232 using data_type_c = float ;
@@ -245,10 +244,10 @@ class Test10 : public TestBase {
245244 static constexpr size_t sg_m = 32 ;
246245 static constexpr size_t sg_n = 64 ;
247246 static constexpr size_t sg_k = 8 ;
248- static constexpr uint32_t global_kslicing = 2 ;
247+ static constexpr uint32_t global_kslicing = 1 ;
249248 static constexpr uint32_t local_kslicing = 1 ;
250249 static constexpr mem_layout layout_a = mem_layout::row_major;
251- static constexpr mem_layout layout_b = mem_layout::col_major ;
250+ static constexpr mem_layout layout_b = mem_layout::row_major ;
252251 using data_type_a = float ;
253252 using data_type_b = float ;
254253 using data_type_c = float ;
@@ -258,9 +257,9 @@ class Test10 : public TestBase {
258257class Test11 : public TestBase {
259258 public:
260259 static constexpr size_t batch_size = 35 ;
261- static constexpr size_t mat_m = 4192 ;
262- static constexpr size_t mat_k = 1136 ;
263- static constexpr size_t mat_n = 688 ;
260+ static constexpr size_t mat_m = 4193 ;
261+ static constexpr size_t mat_k = 1134 ;
262+ static constexpr size_t mat_n = 686 ;
264263 static constexpr size_t wg_m = 256 ;
265264 static constexpr size_t wg_n = 256 ;
266265 static constexpr size_t sg_m = 32 ;
@@ -314,4 +313,4 @@ class result_validate {
314313 Test::layout_a,
315314 Test::layout_b);
316315 }
317- };
316+ };
0 commit comments