@@ -20,6 +20,20 @@ struct mma_int_A_I16K4 {
         GGML_CUDA_ASSUME(ret < K);
         return ret;
     }
+
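+    // Load this 16x4 int tile from xs0 (row stride in ints): one warp-wide ldmatrix.x2 when int8 MMA is available, per-thread scalar loads otherwise.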
+    __device__ __forceinline__ void load(const int * __restrict__ xs0, const int & stride) {
+#if defined(INT8_MMA_AVAILABLE)
+        const int * xs = xs0 + (threadIdx.x%I)*stride + (threadIdx.x/I)*(K/2);
+        asm("ldmatrix.sync.aligned.m8n8.x2.b16 {%0, %1}, [%2];"
+            : "+r"(x[0]), "+r"(x[1])
+            : "l"(xs));
+#else
+#pragma unroll
+        for (int l = 0; l < ne; ++l) {
+            x[l] = xs0[get_i(l)*stride + get_k(l)];
+        }
+#endif // defined(INT8_MMA_AVAILABLE)
+    }
 };
 
 struct mma_int_A_I16K8 {
@@ -42,6 +56,20 @@ struct mma_int_A_I16K8 {
         GGML_CUDA_ASSUME(ret < K);
         return ret;
     }
+
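+    // Load this 16x8 int tile from xs0 (row stride in ints): ldmatrix.x4 when int8 MMA is available, per-thread scalar loads otherwise.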
+    __device__ __forceinline__ void load(const int * __restrict__ xs0, const int & stride) {
+#if defined(INT8_MMA_AVAILABLE)
+        const int * xs = xs0 + (threadIdx.x%I)*stride + (threadIdx.x/I)*(K/2);
+        asm("ldmatrix.sync.aligned.m8n8.x4.b16 {%0, %1, %2, %3}, [%4];"
+            : "+r"(x[0]), "+r"(x[1]), "+r"(x[2]), "+r"(x[3])
+            : "l"(xs));
+#else
+#pragma unroll
+        for (int l = 0; l < ne; ++l) {
+            x[l] = xs0[get_i(l)*stride + get_k(l)];
+        }
+#endif // defined(INT8_MMA_AVAILABLE)
+    }
 };
 
 struct mma_int_B_J8K4 {
@@ -64,6 +92,20 @@ struct mma_int_B_J8K4 {
         GGML_CUDA_ASSUME(ret < K);
         return ret;
     }
+
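+    // Load this 8x4 int tile; the ldmatrix path is compiled out (&& false) because plain 4-byte loads turned out to be faster here.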
+    __device__ __forceinline__ void load(const int * __restrict__ xs0, const int & stride) {
+#if defined(INT8_MMA_AVAILABLE) && false // Loading as 4 byte values is faster
+        const int * xs = xs0 + (threadIdx.x%J)*stride;
+        asm("ldmatrix.sync.aligned.m8n8.x1.b16 {%0}, [%1];"
+            : "+r"(x[0])
+            : "l"(xs));
+#else
+#pragma unroll
+        for (int l = 0; l < ne; ++l) {
+            x[l] = xs0[get_j(l)*stride + get_k(l)];
+        }
+#endif // defined(INT8_MMA_AVAILABLE)
+    }
 };
 
 struct mma_int_B_J8K8 {
@@ -86,6 +128,20 @@ struct mma_int_B_J8K8 {
         GGML_CUDA_ASSUME(ret < K);
         return ret;
     }
+
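+    // Load this 8x8 int tile; as for J8K4, the scalar 4-byte loads are used because they are faster than ldmatrix here.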
+    __device__ __forceinline__ void load(const int * __restrict__ xs0, const int & stride) {
+#if defined(INT8_MMA_AVAILABLE) && false // Loading as 4 byte values is faster
+        const int * xs = xs0 + (threadIdx.x%J)*stride + ((threadIdx.x/J)*(K/2)) % K;
+        asm("ldmatrix.sync.aligned.m8n8.x2.b16 {%0, %1}, [%2];"
+            : "+r"(x[0]), "+r"(x[1])
+            : "l"(xs));
+#else
+#pragma unroll
+        for (int l = 0; l < ne; ++l) {
+            x[l] = xs0[get_j(l)*stride + get_k(l)];
+        }
+#endif // defined(INT8_MMA_AVAILABLE)
+    }
 };
 
 struct mma_int_C_I16J8 {
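
For orientation, a minimal usage sketch under stated assumptions (not part of this commit): one warp loads an A and a B fragment from shared memory and feeds them to the mma_int_C_I16J8 fragment whose definition begins just above. The tile pointers, strides, the function name, and the C.mma_K8 member are illustrative assumptions based on the surrounding file rather than anything shown in this diff.

__device__ void warp_int8_tile_step(const int * __restrict__ tile_a,  // shared-memory A tile (hypothetical)
                                    const int * __restrict__ tile_b,  // shared-memory B tile (hypothetical)
                                    const int stride_a, const int stride_b,
                                    mma_int_C_I16J8 & C) {
    mma_int_A_I16K8 A; // 16x8 int fragment, 4 ints per thread
    mma_int_B_J8K8  B; //  8x8 int fragment, 2 ints per thread
    A.load(tile_a, stride_a); // ldmatrix.x4 on int8-MMA hardware, scalar loop otherwise
    B.load(tile_b, stride_b); // always the scalar path (ldmatrix is disabled above)
    C.mma_K8(A, B);           // assumed accumulate member on the C fragment (not shown in this diff)
}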