Skip to content

Commit 8267066

Browse files
authored
Add LSX support for S8S8 and S8U8 GEMM kernels (microsoft#24397)
### Description

- Add missing support for S8S8/S8U8 in GEMM kernels of LSX
- Add new dispatch entries for S8S8 and S8U8 GEMM operations in mlasi.h
- Extend MLAS_PLATFORM struct to include S8S8 and S8U8 dispatch pointers for LSX

### Motivation and Context

To fix [build error](lcpu-club/loongarch-packages#526 (comment)) on loong64:

```
error: ‘struct MLAS_PLATFORM’ has no member named ‘GemmS8S8Dispatch’
```

### Test status

Tested on Arch Linux for Loong64, here is the build log:

* [onnxruntime-1.20.2-7.1-loong64-build.log](https://github.com/user-attachments/files/19710083/onnxruntime-1.20.2-7.1-loong64-build.log)

Signed-off-by: Zhou Qiankang <wszqkzqk@qq.com>
1 parent f20df72 commit 8267066

File tree

4 files changed

+399
-1
lines changed

4 files changed

+399
-1
lines changed

onnxruntime/core/mlas/lib/mlasi.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1153,6 +1153,8 @@ struct MLAS_GEMM_QUANT_DISPATCH;
11531153

11541154
extern const MLAS_GEMM_QUANT_DISPATCH MlasGemmU8X8DispatchSse;
11551155
extern const MLAS_GEMM_QUANT_DISPATCH MlasGemmU8X8DispatchLSX;
1156+
extern const MLAS_GEMM_QUANT_DISPATCH MlasGemmS8S8DispatchLSX;
1157+
extern const MLAS_GEMM_QUANT_DISPATCH MlasGemmS8U8DispatchLSX;
11561158
extern const MLAS_GEMM_QUANT_DISPATCH MlasGemmU8S8DispatchSse41;
11571159
extern const MLAS_GEMM_QUANT_DISPATCH MlasGemmU8S8DispatchAvx2;
11581160
extern const MLAS_GEMM_QUANT_DISPATCH MlasGemmU8U8DispatchAvx2;
@@ -1337,6 +1339,8 @@ struct MLAS_PLATFORM {
13371339
#if defined(MLAS_TARGET_LARCH64)
13381340
const MLAS_GEMM_QUANT_DISPATCH* GemmU8S8Dispatch;
13391341
const MLAS_GEMM_QUANT_DISPATCH* GemmU8U8Dispatch;
1342+
const MLAS_GEMM_QUANT_DISPATCH* GemmS8S8Dispatch;
1343+
const MLAS_GEMM_QUANT_DISPATCH* GemmS8U8Dispatch;
13401344
MLAS_GEMM_FLOAT_KERNEL* GemmFloatKernel;
13411345
MLAS_GEMM_DOUBLE_KERNEL* GemmDoubleKernel;
13421346
MLAS_CONV_FLOAT_KERNEL* ConvNchwFloatKernel;

onnxruntime/core/mlas/lib/platform.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -747,10 +747,14 @@ Return Value:
747747

748748
this->GemmU8S8Dispatch = &MlasGemmU8X8DispatchLSX;
749749
this->GemmU8U8Dispatch = &MlasGemmU8X8DispatchLSX;
750+
this->GemmS8S8Dispatch = &MlasGemmS8S8DispatchLSX;
751+
this->GemmS8U8Dispatch = &MlasGemmS8U8DispatchLSX;
750752
}else if( cap_lsx ){
751753
this->GemmFloatKernel = MlasGemmFloatKernelLSX;
752754
this->GemmU8S8Dispatch = &MlasGemmU8X8DispatchLSX;
753755
this->GemmU8U8Dispatch = &MlasGemmU8X8DispatchLSX;
756+
this->GemmS8S8Dispatch = &MlasGemmS8S8DispatchLSX;
757+
this->GemmS8U8Dispatch = &MlasGemmS8U8DispatchLSX;
754758
this->TransposePackB16x4Routine = MlasSgemmTransposePackB16x4LSX;
755759
this->GemmDoubleKernel = MlasGemmDoubleKernelLSX;
756760
this->ConvNchwFloatKernel = MlasConvNchwFloatKernelLSX;

onnxruntime/core/mlas/lib/qgemm.h

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -905,7 +905,10 @@ MlasGemmQuantGetDispatch(
905905
GemmQuantDispatch = GetMlasPlatform().GemmU8X8Dispatch;
906906
}
907907
#elif defined(MLAS_TARGET_LARCH64)
908-
if (!AIsSigned) {
908+
if (AIsSigned) {
909+
GemmQuantDispatch =
910+
BIsSigned ? GetMlasPlatform().GemmS8S8Dispatch : GetMlasPlatform().GemmS8U8Dispatch;
911+
} else { // !AIsSigned
909912
GemmQuantDispatch =
910913
BIsSigned ? GetMlasPlatform().GemmU8S8Dispatch : GetMlasPlatform().GemmU8U8Dispatch;
911914
}

0 commit comments

Comments
 (0)