@@ -16,10 +16,11 @@ Module Name:
1616--*/
1717
1818#include " sqnbitgemm.h"
19- #include " sqnbitgemm_q8_block.h"
2019
2120#include < cassert>
2221
22+ #include " sqnbitgemm_q8_block.h"
23+
2324namespace
2425{
2526
@@ -80,7 +81,7 @@ MlasIsSQNBitGemmAvailable(
8081 Dispatch->Q4BitBlkDequantBForSgemm_CompFp32 != nullptr ;
8182 }
8283 case SQNBitGemmVariant_BitWidth4_CompInt8: {
83- return Dispatch->SQ4BitGemmM1Kernel_CompInt8 != nullptr &&
84+ return Dispatch->SQ4BitGemmKernel_CompInt8 != nullptr &&
8485 Dispatch->QuantizeARow_CompInt8 != nullptr ;
8586 }
8687 default : {
@@ -372,15 +373,17 @@ SQ4BitGemm_CompFp32(
372373 if (bias) {
373374 AddBiasForGemm (bias, c_blk, RowsHandled, CountN, ldc);
374375 }
376+
375377 if (DataParams->PostProcessor != nullptr ) {
376378 DataParams->PostProcessor ->Process (
377- DataParams->C , RangeStartM + RangeCountM - RowsRemaining, RangeStartN,
379+ DataParams->C , RangeStartM + RangeCountM - RowsRemaining, RangeStartN + n ,
378380 RowsHandled, CountN, ldc
379381 );
380382 }
381383
382384 c_blk += ldc * RowsHandled;
383385 a_row += lda * RowsHandled;
386+
384387 RowsRemaining -= RowsHandled;
385388 }
386389 }
@@ -431,36 +434,6 @@ SQ4BitGemm_CompInt8(
431434
432435 const float * Bias = (DataParams->Bias == nullptr ) ? nullptr : DataParams->Bias + RangeStartN;
433436
434- if (RangeCountM == 1 ) {
435- size_t CountN;
436- for (size_t n = 0 ; n < RangeCountN; n += CountN) {
437- CountN = std::min (RangeCountN - n, size_t {128 });
438-
439- const std::byte* a_row = QuantA;
440- const std::byte* b_col = QuantBData + n * ldb;
441- const float * b_col_scale = QuantBScale + n * k_blks;
442- const std::byte* b_col_zp =
443- (QuantBZeroPoint == nullptr ) ? nullptr : QuantBZeroPoint + n * k_blks_zp_bytes;
444- float * c_blk = C + n;
445- const float * bias = (Bias == nullptr ) ? nullptr : Bias + n;
446-
447- GetMlasPlatform ().SQNBitGemmDispatch ->SQ4BitGemmM1Kernel_CompInt8 (
448- BlkLen,
449- a_row, b_col, b_col_scale, b_col_zp, c_blk, CountN, K, k_blks, bias
450- );
451-
452- if (DataParams->PostProcessor != nullptr ) {
453- DataParams->PostProcessor ->Process (
454- DataParams->C , RangeStartM, RangeStartN + n,
455- RangeCountM, CountN, ldc
456- );
457- }
458- }
459- return ;
460- }
461-
462- // This is a naive M > 1 implementation that repeatedly calls the M=1 kernel.
463- // TODO Replace it with an optimized implementation.
464437 size_t CountN;
465438 for (size_t n = 0 ; n < RangeCountN; n += CountN) {
466439 CountN = std::min (RangeCountN - n, size_t {128 });
@@ -473,21 +446,24 @@ SQ4BitGemm_CompInt8(
473446 float * c_blk = C + n;
474447 const float * bias = (Bias == nullptr ) ? nullptr : Bias + n;
475448
476- for (size_t m = 0 ; m < RangeCountM; ++m) {
477- GetMlasPlatform ().SQNBitGemmDispatch ->SQ4BitGemmM1Kernel_CompInt8 (
449+ size_t RowsRemaining = RangeCountM;
450+ while (RowsRemaining > 0 ) {
451+ const auto RowsHandled = GetMlasPlatform ().SQNBitGemmDispatch ->SQ4BitGemmKernel_CompInt8 (
478452 BlkLen,
479- a_row, b_col, b_col_scale, b_col_zp, c_blk, CountN, K, k_blks, bias
453+ a_row, b_col, b_col_scale, b_col_zp, c_blk, RowsRemaining, CountN, K, k_blks, ldc , bias
480454 );
481455
482456 if (DataParams->PostProcessor != nullptr ) {
483457 DataParams->PostProcessor ->Process (
484- DataParams->C , RangeStartM, RangeStartN + n,
485- RangeCountM , CountN, ldc
458+ DataParams->C , RangeStartM + RangeCountM - RowsRemaining , RangeStartN + n,
459+ RowsHandled , CountN, ldc
486460 );
487461 }
488462
489- c_blk += ldc;
490- a_row += lda;
463+ c_blk += RowsHandled * ldc;
464+ a_row += RowsHandled * lda;
465+
466+ RowsRemaining -= RowsHandled;
491467 }
492468 }
493469}
0 commit comments