Skip to content

Commit 20cd339

Browse files
authored
[MLAS] AArch64 SQNBitGemm CompInt8 initial multi-row implementation (microsoft#21193)
Update the AArch64 SQNBitGemm CompInt8 kernels to process the matrix in tiles. For example, computing the output in 2x2 tiles lets us compute four elements of the output with a single read of two rows of A and two columns of B. Also moved some code into separate files, as it was getting too big for a single file.
1 parent 8749fa3 commit 20cd339

12 files changed

+2248
-1384
lines changed

cmake/onnxruntime_mlas.cmake

Lines changed: 7 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -82,7 +82,10 @@ function(setup_mlas_source_for_windows)
8282
${MLAS_SRC_DIR}/qgemm_kernel_neon.cpp
8383
${MLAS_SRC_DIR}/qgemm_kernel_udot.cpp
8484
${MLAS_SRC_DIR}/qgemm_kernel_sdot.cpp
85+
${MLAS_SRC_DIR}/sqnbitgemm_kernel_neon.h
8586
${MLAS_SRC_DIR}/sqnbitgemm_kernel_neon.cpp
87+
${MLAS_SRC_DIR}/sqnbitgemm_kernel_neon_fp32.cpp
88+
${MLAS_SRC_DIR}/sqnbitgemm_kernel_neon_int8.cpp
8689
)
8790

8891
set(mlas_platform_preprocess_srcs
@@ -350,9 +353,12 @@ else()
350353
${MLAS_SRC_DIR}/qgemm_kernel_neon.cpp
351354
${MLAS_SRC_DIR}/qgemm_kernel_udot.cpp
352355
${MLAS_SRC_DIR}/qgemm_kernel_sdot.cpp
356+
${MLAS_SRC_DIR}/sqnbitgemm_kernel_neon.h
353357
${MLAS_SRC_DIR}/sqnbitgemm_kernel_neon.cpp
358+
${MLAS_SRC_DIR}/sqnbitgemm_kernel_neon_fp32.cpp
359+
${MLAS_SRC_DIR}/sqnbitgemm_kernel_neon_int8.cpp
354360
)
355-
set_source_files_properties(${MLAS_SRC_DIR}/sqnbitgemm_kernel_neon.cpp
361+
set_source_files_properties(${MLAS_SRC_DIR}/sqnbitgemm_kernel_neon_int8.cpp
356362
PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+dotprod")
357363
if (NOT APPLE)
358364
set(mlas_platform_srcs

onnxruntime/core/mlas/lib/sqnbitgemm.cpp

Lines changed: 16 additions & 40 deletions
Original file line number | Diff line number | Diff line change
@@ -16,10 +16,11 @@ Module Name:
1616
--*/
1717

1818
#include "sqnbitgemm.h"
19-
#include "sqnbitgemm_q8_block.h"
2019

2120
#include <cassert>
2221

22+
#include "sqnbitgemm_q8_block.h"
23+
2324
namespace
2425
{
2526

@@ -80,7 +81,7 @@ MlasIsSQNBitGemmAvailable(
8081
Dispatch->Q4BitBlkDequantBForSgemm_CompFp32 != nullptr;
8182
}
8283
case SQNBitGemmVariant_BitWidth4_CompInt8: {
83-
return Dispatch->SQ4BitGemmM1Kernel_CompInt8 != nullptr &&
84+
return Dispatch->SQ4BitGemmKernel_CompInt8 != nullptr &&
8485
Dispatch->QuantizeARow_CompInt8 != nullptr;
8586
}
8687
default: {
@@ -372,15 +373,17 @@ SQ4BitGemm_CompFp32(
372373
if (bias) {
373374
AddBiasForGemm(bias, c_blk, RowsHandled, CountN, ldc);
374375
}
376+
375377
if (DataParams->PostProcessor != nullptr) {
376378
DataParams->PostProcessor->Process(
377-
DataParams->C, RangeStartM + RangeCountM - RowsRemaining, RangeStartN,
379+
DataParams->C, RangeStartM + RangeCountM - RowsRemaining, RangeStartN + n,
378380
RowsHandled, CountN, ldc
379381
);
380382
}
381383

382384
c_blk += ldc * RowsHandled;
383385
a_row += lda * RowsHandled;
386+
384387
RowsRemaining -= RowsHandled;
385388
}
386389
}
@@ -431,36 +434,6 @@ SQ4BitGemm_CompInt8(
431434

432435
const float* Bias = (DataParams->Bias == nullptr) ? nullptr : DataParams->Bias + RangeStartN;
433436

434-
if (RangeCountM == 1) {
435-
size_t CountN;
436-
for (size_t n = 0; n < RangeCountN; n += CountN) {
437-
CountN = std::min(RangeCountN - n, size_t{128});
438-
439-
const std::byte* a_row = QuantA;
440-
const std::byte* b_col = QuantBData + n * ldb;
441-
const float* b_col_scale = QuantBScale + n * k_blks;
442-
const std::byte* b_col_zp =
443-
(QuantBZeroPoint == nullptr) ? nullptr : QuantBZeroPoint + n * k_blks_zp_bytes;
444-
float* c_blk = C + n;
445-
const float* bias = (Bias == nullptr) ? nullptr : Bias + n;
446-
447-
GetMlasPlatform().SQNBitGemmDispatch->SQ4BitGemmM1Kernel_CompInt8(
448-
BlkLen,
449-
a_row, b_col, b_col_scale, b_col_zp, c_blk, CountN, K, k_blks, bias
450-
);
451-
452-
if (DataParams->PostProcessor != nullptr) {
453-
DataParams->PostProcessor->Process(
454-
DataParams->C, RangeStartM, RangeStartN + n,
455-
RangeCountM, CountN, ldc
456-
);
457-
}
458-
}
459-
return;
460-
}
461-
462-
// This is a naive M > 1 implementation that repeatedly calls the M=1 kernel.
463-
// TODO Replace it with an optimized implementation.
464437
size_t CountN;
465438
for (size_t n = 0; n < RangeCountN; n += CountN) {
466439
CountN = std::min(RangeCountN - n, size_t{128});
@@ -473,21 +446,24 @@ SQ4BitGemm_CompInt8(
473446
float* c_blk = C + n;
474447
const float* bias = (Bias == nullptr) ? nullptr : Bias + n;
475448

476-
for (size_t m = 0; m < RangeCountM; ++m) {
477-
GetMlasPlatform().SQNBitGemmDispatch->SQ4BitGemmM1Kernel_CompInt8(
449+
size_t RowsRemaining = RangeCountM;
450+
while (RowsRemaining > 0) {
451+
const auto RowsHandled = GetMlasPlatform().SQNBitGemmDispatch->SQ4BitGemmKernel_CompInt8(
478452
BlkLen,
479-
a_row, b_col, b_col_scale, b_col_zp, c_blk, CountN, K, k_blks, bias
453+
a_row, b_col, b_col_scale, b_col_zp, c_blk, RowsRemaining, CountN, K, k_blks, ldc, bias
480454
);
481455

482456
if (DataParams->PostProcessor != nullptr) {
483457
DataParams->PostProcessor->Process(
484-
DataParams->C, RangeStartM, RangeStartN + n,
485-
RangeCountM, CountN, ldc
458+
DataParams->C, RangeStartM + RangeCountM - RowsRemaining, RangeStartN + n,
459+
RowsHandled, CountN, ldc
486460
);
487461
}
488462

489-
c_blk += ldc;
490-
a_row += lda;
463+
c_blk += RowsHandled * ldc;
464+
a_row += RowsHandled * lda;
465+
466+
RowsRemaining -= RowsHandled;
491467
}
492468
}
493469
}

onnxruntime/core/mlas/lib/sqnbitgemm.h

Lines changed: 11 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -184,7 +184,6 @@ struct MLAS_SQNBIT_GEMM_DISPATCH {
184184
/**
185185
* @brief Multiply quantized 8-bit integer matrix A with quantized 4-bit integer matrix B.
186186
* A and B are block quantized and B is column major.
187-
* This kernel handles the special case where M, the number of rows of A and C, is 1.
188187
*
189188
* @param BlkLen Number of values in a block.
190189
* @param QuantA Supplies the quantized A matrix.
@@ -193,25 +192,31 @@ struct MLAS_SQNBIT_GEMM_DISPATCH {
193192
* @param QuantBScale Supplies the quantized B matrix block scale values.
194193
* @param QuantBZeroPoint Supplies the quantized B matrix block zero point values. Optional.
195194
* @param[out] C Supplies the output C matrix.
196-
* @param CountN Number of columns of B and C.
195+
* @param CountM Number of rows of A and C to process, an upper bound.
196+
* @param CountN Number of columns of B and C to process.
197197
* @param CountK Number of columns of A and rows of B.
198-
* @param BlockStrideQuantB Number of blocks between adjacent columns of the quantized B matrix.
198+
* @param BlockCountK Number of blocks in one row of A and one column of B.
199+
* @param ldc Number of elements between adjacent rows of C.
199200
* @param Bias Bias vector of length N.
201+
*
202+
* @return The number of rows of A and C that were processed, at most CountM.
200203
*/
201-
typedef void(SQ4BitGemmM1Kernel_CompInt8_Fn)(
204+
typedef size_t(SQ4BitGemmKernel_CompInt8_Fn)(
202205
size_t BlkLen,
203206
const std::byte* QuantA,
204207
const std::byte* QuantBData,
205208
const float* QuantBScale,
206209
const std::byte* QuantBZeroPoint,
207210
float* C,
211+
size_t CountM,
208212
size_t CountN,
209213
size_t CountK,
210-
size_t BlockStrideQuantB,
214+
size_t BlockCountK,
215+
size_t ldc,
211216
const float* Bias
212217
);
213218

214-
SQ4BitGemmM1Kernel_CompInt8_Fn* SQ4BitGemmM1Kernel_CompInt8 = nullptr;
219+
SQ4BitGemmKernel_CompInt8_Fn* SQ4BitGemmKernel_CompInt8 = nullptr;
215220

216221
/**
217222
* @brief Block quantize values from one row of matrix A from floats to quantized 8-bit integers.

onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx2.cpp

Lines changed: 39 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -434,6 +434,44 @@ SQ4BitGemmM1Kernel_CompInt8_avx2(
434434
}
435435
}
436436

437+
size_t
438+
SQ4BitGemmKernel_CompInt8_avx2(
439+
size_t BlkLen,
440+
const std::byte* QuantA,
441+
const std::byte* QuantBData,
442+
const float* QuantBScale,
443+
const std::byte* QuantBZeroPoint,
444+
float* C,
445+
size_t CountM,
446+
size_t CountN,
447+
size_t CountK,
448+
size_t BlockCountK,
449+
size_t ldc,
450+
const float* Bias
451+
)
452+
{
453+
MLAS_UNREFERENCED_PARAMETER(ldc);
454+
455+
if (CountM == 0) {
456+
return 0;
457+
}
458+
459+
SQ4BitGemmM1Kernel_CompInt8_avx2(
460+
BlkLen,
461+
QuantA,
462+
QuantBData,
463+
QuantBScale,
464+
QuantBZeroPoint,
465+
C,
466+
CountN,
467+
CountK,
468+
BlockCountK,
469+
Bias
470+
);
471+
472+
return 1;
473+
}
474+
437475
template <size_t NCols, bool HasZeroPoint>
438476
MLAS_FORCEINLINE void
439477
ComputeDotProducts_BlkLen16_CompFp32_avx2(
@@ -1109,7 +1147,7 @@ const MLAS_SQNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchAvx2 = []() {
11091147
d.SQ4BitGemmM1Kernel_CompFp32 = SQ4BitGemmM1Kernel_CompFp32_avx2;
11101148
d.Q4BitBlkDequantBForSgemm_CompFp32 = Q4BitBlkDequantBForSgemm_CompFp32_avx2;
11111149

1112-
d.SQ4BitGemmM1Kernel_CompInt8 = SQ4BitGemmM1Kernel_CompInt8_avx2;
1150+
d.SQ4BitGemmKernel_CompInt8 = SQ4BitGemmKernel_CompInt8_avx2;
11131151
d.QuantizeARow_CompInt8 = QuantizeARow_CompInt8_avx2;
11141152

11151153
return d;

onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512.cpp

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -239,7 +239,7 @@ const MLAS_SQNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchAvx512 = []() {
239239
d.SQ4BitGemmM1Kernel_CompFp32 = SQ4BitGemmM1Kernel_CompFp32_avx512;
240240
d.Q4BitBlkDequantBForSgemm_CompFp32 = Q4BitBlkDequantBForSgemm_CompFp32_avx2;
241241

242-
d.SQ4BitGemmM1Kernel_CompInt8 = SQ4BitGemmM1Kernel_CompInt8_avx2;
242+
d.SQ4BitGemmKernel_CompInt8 = SQ4BitGemmKernel_CompInt8_avx2;
243243
d.QuantizeARow_CompInt8 = QuantizeARow_CompInt8_avx512;
244244

245245
return d;

onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512vnni.cpp

Lines changed: 39 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -237,6 +237,44 @@ SQ4BitGemmM1Kernel_CompInt8_avx512vnni(
237237
}
238238
}
239239

240+
size_t
241+
SQ4BitGemmKernel_CompInt8_avx512vnni(
242+
size_t BlkLen,
243+
const std::byte* QuantA,
244+
const std::byte* QuantBData,
245+
const float* QuantBScale,
246+
const std::byte* QuantBZeroPoint,
247+
float* C,
248+
size_t CountM,
249+
size_t CountN,
250+
size_t CountK,
251+
size_t BlockCountK,
252+
size_t ldc,
253+
const float* Bias
254+
)
255+
{
256+
MLAS_UNREFERENCED_PARAMETER(ldc);
257+
258+
if (CountM == 0) {
259+
return 0;
260+
}
261+
262+
SQ4BitGemmM1Kernel_CompInt8_avx512vnni(
263+
BlkLen,
264+
QuantA,
265+
QuantBData,
266+
QuantBScale,
267+
QuantBZeroPoint,
268+
C,
269+
CountN,
270+
CountK,
271+
BlockCountK,
272+
Bias
273+
);
274+
275+
return 1;
276+
}
277+
240278
void MLASCALL
241279
MlasQ80BlkQuantRow_avx512(
242280
size_t BlkLen,
@@ -260,7 +298,7 @@ const MLAS_SQNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchAvx512vnni = []() {
260298
d.SQ4BitGemmM1Kernel_CompFp32 = SQ4BitGemmM1Kernel_CompFp32;
261299
d.Q4BitBlkDequantBForSgemm_CompFp32 = Q4BitBlkDequantBForSgemm_CompFp32_avx2;
262300

263-
d.SQ4BitGemmM1Kernel_CompInt8 = SQ4BitGemmM1Kernel_CompInt8_avx512vnni;
301+
d.SQ4BitGemmKernel_CompInt8 = SQ4BitGemmKernel_CompInt8_avx512vnni;
264302
d.QuantizeARow_CompInt8 = MlasQ80BlkQuantRow_avx512;
265303

266304
return d;

onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx_common.h

Lines changed: 5 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -158,17 +158,19 @@ Q4BitBlkDequantBForSgemm_CompFp32_avx2(
158158
const size_t BlockStrideQuantB
159159
);
160160

161-
void
162-
SQ4BitGemmM1Kernel_CompInt8_avx2(
161+
size_t
162+
SQ4BitGemmKernel_CompInt8_avx2(
163163
size_t BlkLen,
164164
const std::byte* QuantA,
165165
const std::byte* QuantBData,
166166
const float* QuantBScale,
167167
const std::byte* QuantBZeroPoint,
168168
float* C,
169+
size_t CountM,
169170
size_t CountN,
170171
size_t CountK,
171-
size_t BlockStrideQuantB,
172+
size_t BlockCountK,
173+
size_t ldc,
172174
const float* Bias
173175
);
174176

0 commit comments

Comments
 (0)