@@ -814,7 +814,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, "ttg.thr
%BB_DOT = ttg.local_load %BB : !ttg.memdesc<16x16xf16, #shared0, #smem> -> tensor<16x16xf16, #dot_operand_b>
%cst0 = arith.constant dense<0.000000e+00> : tensor<16x16xf32, #dpas0>

- // CHECK-COUNT-2: llvm.call spir_funccc @_Z38intel_sub_group_f16_f16_matrix_mad_k16Dv8_sDv8_iDv8_f(%{{.*}}, %{{.*}}, %{{.*}}) {{.*}} : (vector<8xi16>, vector<8xi32>, vector<8xf32>) -> vector<8xf32>
+ // CHECK-COUNT-2: llvm.call spir_funccc @_Z45__spirv_SubgroupMatrixMultiplyAccumulateINTELiDv8_sDv8_iDv8_fi(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) {{.*}} : (i32, vector<8xi16>, vector<8xi32>, vector<8xf32>, i32) -> vector<8xf32>
%D = tt.dot %AA_DOT, %BB_DOT, %cst0 : tensor<16x16xf16, #dot_operand_a> * tensor<16x16xf16, #dot_operand_b> -> tensor<16x16xf32, #dpas0>

tt.return
@@ -964,7 +964,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} {
%a_mat = ttg.local_load %a : !ttg.memdesc<128x32xf16, #shared, #smem> -> tensor<128x32xf16, #dot_operand_a>
%b_mat = ttg.local_load %b : !ttg.memdesc<32x256xf16, #shared, #smem> -> tensor<32x256xf16, #dot_operand_b>

- // CHECK-COUNT-128: llvm.call spir_funccc @_Z38intel_sub_group_f16_f16_matrix_mad_k16Dv8_sDv8_iDv8_f(%{{.*}}, %{{.*}}, %{{.*}}) {{.*}} : (vector<8xi16>, vector<8xi32>, vector<8xf32>) -> vector<8xf32>
+ // CHECK-COUNT-128: llvm.call spir_funccc @_Z45__spirv_SubgroupMatrixMultiplyAccumulateINTELiDv8_sDv8_iDv8_fi(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) {{.*}} : (i32, vector<8xi16>, vector<8xi32>, vector<8xf32>, i32) -> vector<8xf32>
%28 = tt.dot %a_mat, %b_mat, %cst : tensor<128x32xf16, #dot_operand_a> * tensor<32x256xf16, #dot_operand_b> -> tensor<128x256xf32, #dpas>
%38 = ttg.convert_layout %28 : tensor<128x256xf32, #dpas> -> tensor<128x256xf32, #blocked>

@@ -991,7 +991,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} {
%a_mat = ttg.local_load %a : !ttg.memdesc<32x64xf16, #shared0, #smem> -> tensor<32x64xf16, #dot_operand_a>
%b_mat = ttg.local_load %b : !ttg.memdesc<64x64xf16, #shared1, #smem> -> tensor<64x64xf16, #dot_operand_b>

- // CHECK-COUNT-16: llvm.call spir_funccc @_Z38intel_sub_group_f16_f16_matrix_mad_k16Dv8_sDv8_iDv8_f(%{{.*}}, %{{.*}}, %{{.*}}) {{.*}} : (vector<8xi16>, vector<8xi32>, vector<8xf32>) -> vector<8xf32>
+ // CHECK-COUNT-16: llvm.call spir_funccc @_Z45__spirv_SubgroupMatrixMultiplyAccumulateINTELiDv8_sDv8_iDv8_fi(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) {{.*}} : (i32, vector<8xi16>, vector<8xi32>, vector<8xf32>, i32) -> vector<8xf32>
%28 = tt.dot %a_mat, %b_mat, %cst : tensor<32x64xf16, #dot_operand_a> * tensor<64x64xf16, #dot_operand_b> -> tensor<32x64xf32, #dpas>
%38 = ttg.convert_layout %28 : tensor<32x64xf32, #dpas> -> tensor<32x64xf32, #blocked>
%30 = tt.splat %ptr : !tt.ptr<f32> -> tensor<32x1x!tt.ptr<f32>, #blocked>
@@ -1040,7 +1040,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} {
%a_mat = ttg.local_load %a : !ttg.memdesc<32x16xf32, #shared, #smem> -> tensor<32x16xf32, #dot_operand_a>
%b_mat = ttg.local_load %b : !ttg.memdesc<16x32xf32, #shared, #smem> -> tensor<16x32xf32, #dot_operand_b>

- // CHECK-COUNT-2: llvm.call spir_funccc @_Z39intel_sub_group_tf32_tf32_matrix_mad_k8Dv8_fS_S_(%{{.*}}, %{{.*}}, %{{.*}}) {{.*}} : (vector<8xf32>, vector<8xf32>, vector<8xf32>) -> vector<8xf32>
+ // CHECK-COUNT-2: llvm.call spir_funccc @_Z45__spirv_SubgroupMatrixMultiplyAccumulateINTELiDv8_fS_S_i(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) {{.*}} : (i32, vector<8xf32>, vector<8xf32>, vector<8xf32>, i32) -> vector<8xf32>
%28 = tt.dot %a_mat, %b_mat, %cst, inputPrecision = tf32 : tensor<32x16xf32, #dot_operand_a> * tensor<16x32xf32, #dot_operand_b> -> tensor<32x32xf32, #dpas>
%38 = ttg.convert_layout %28 : tensor<32x32xf32, #dpas> -> tensor<32x32xf32, #blocked>

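For reference, the updated CHECK lines expect the SPIR-V friendly-name intrinsic __spirv_SubgroupMatrixMultiplyAccumulateINTEL in place of the OpenCL-style intel_sub_group_*_matrix_mad_* builtins; judging from the mangled names and call types, it carries two extra i32 operands around the same a/b/accumulator vectors (under the SPV_INTEL_subgroup_matrix_multiply_accumulate extension these correspond to the K dimension and an operands flag). A minimal sketch of the implied LLVM-dialect declarations, inferred only from the CHECK lines above and not part of the diff:

    // Sketch only: declarations reconstructed from the mangled names in the CHECK lines.
    // Old OpenCL builtin: (a, b, acc) -> result
    llvm.func spir_funccc @_Z38intel_sub_group_f16_f16_matrix_mad_k16Dv8_sDv8_iDv8_f(vector<8xi16>, vector<8xi32>, vector<8xf32>) -> vector<8xf32>
    // New SPIR-V friendly-name intrinsic: (k_dim, a, b, acc, operands_flag) -> result
    llvm.func spir_funccc @_Z45__spirv_SubgroupMatrixMultiplyAccumulateINTELiDv8_sDv8_iDv8_fi(i32, vector<8xi16>, vector<8xi32>, vector<8xf32>, i32) -> vector<8xf32>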