@@ -321,13 +321,28 @@ static inline JBLAS_CODE decompress_kblock_bit4_packrow1(utils::bit4x2* srcptr,
321321 vzps[iv] = _mm512_cvtepi8_epi32 (tmp);
322322 }
323323 }
324- }
325- for (; irow < row; irow++) {
326- pad_bit4 (tmpbuf, reinterpret_cast <int8_t *>(srcptr + irow * ld_src / 2 ), zmm_mask, LoadMask48);
327- if constexpr (_IS_SYM) {
328- dequantize (dstptr + irow * ld_dst, tmpbuf, vscales, nullptr );
329- } else {
330- dequantize (dstptr + irow * ld_dst, tmpbuf, vscales, vzps);
324+ auto rowre = row - irow;
325+ int rowpad4 = utils::padto_le (rowre, UnrollRow) + irow;
326+ for (; irow < rowpad4; irow += UnrollRow) {
327+ for (int iter64 = 0 ; iter64 < Loop64; iter64++) {
328+ pad_bit4 (tmpbuf + iter64 * 64 , reinterpret_cast <int8_t *>(srcptr + irow * ld_src / 2 + 32 * iter64), zmm_mask,
329+ LoadMask64);
330+ }
331+ for (int iterr = 0 ; iterr < UnrollRow; iterr++) {
332+ if constexpr (_IS_SYM) {
333+ dequantize (dstptr + (irow + iterr) * ld_dst, tmpbuf + iterr * ColTile, vscales, nullptr );
334+ } else {
335+ dequantize (dstptr + (irow + iterr) * ld_dst, tmpbuf + iterr * ColTile, vscales, vzps);
336+ }
337+ }
338+ }
339+ for (; irow < row; irow++) {
340+ pad_bit4 (tmpbuf, reinterpret_cast <int8_t *>(srcptr + irow * ld_src / 2 ), zmm_mask, LoadMask48);
341+ if constexpr (_IS_SYM) {
342+ dequantize (dstptr + irow * ld_dst, tmpbuf, vscales, nullptr );
343+ } else {
344+ dequantize (dstptr + irow * ld_dst, tmpbuf, vscales, vzps);
345+ }
331346 }
332347 }
333348 return JblasSuccess;
@@ -563,9 +578,8 @@ inline JBLAS_CODE decompress_kblock_f8_fp(utils::f8* srcptr, _DST_T* dstptr, int
563578 auto sptr = scales + kpos * NPad;
564579 int j = 0 ;
565580 auto quant = [&](__mmask16 mask) {
566- __m128i f8_src;
567581 auto sign_revert =
568- _mm512_cvtepi8_epi32 (_mm_mask_loadu_epi8 (f8_src, mask, reinterpret_cast <__m128i*>(srcptr + i * ld_src + j)));
582+ _mm512_cvtepi8_epi32 (_mm_maskz_loadu_epi8 ( mask, reinterpret_cast <__m128i*>(srcptr + i * ld_src + j)));
569583 auto e_revert = sign_revert;
570584 auto mantissa_revert = sign_revert;
571585 sign_revert = _mm512_slli_epi32 (sign_revert, 24 );
@@ -888,10 +902,10 @@ inline void f32_f4_quantize_4x16(const float* srcptr, int8_t* dstptr, int ld_src
888902 zmm2 = _mm512_add_ps (zmm2, zmm_zp);
889903 zmm3 = _mm512_add_ps (zmm3, zmm_zp);
890904 } else {
891- mask4 = _mm512_cmplt_ps_mask (zmm0, zmm_v0);
892- mask5 = _mm512_cmplt_ps_mask (zmm1, zmm_v0);
893- mask6 = _mm512_cmplt_ps_mask (zmm2, zmm_v0);
894- mask7 = _mm512_cmplt_ps_mask (zmm3, zmm_v0);
905+ mask4 = _mm512_cmp_ps_mask (zmm0, zmm_v0, 1 );
906+ mask5 = _mm512_cmp_ps_mask (zmm1, zmm_v0, 1 );
907+ mask6 = _mm512_cmp_ps_mask (zmm2, zmm_v0, 1 );
908+ mask7 = _mm512_cmp_ps_mask (zmm3, zmm_v0, 1 );
895909
896910 zmm0 = _mm512_abs_ps (zmm0);
897911 zmm1 = _mm512_abs_ps (zmm1);
@@ -908,10 +922,10 @@ inline void f32_f4_quantize_4x16(const float* srcptr, int8_t* dstptr, int ld_src
908922 zmm5 = _mm512_sub_ps (zmm1, sub_v);
909923 zmm6 = _mm512_sub_ps (zmm2, sub_v);
910924 zmm7 = _mm512_sub_ps (zmm3, sub_v);
911- mask0 = _mm512_cmple_ps_mask (zmm4, zmm_v0);
912- mask1 = _mm512_cmple_ps_mask (zmm5, zmm_v0);
913- mask2 = _mm512_cmple_ps_mask (zmm6, zmm_v0);
914- mask3 = _mm512_cmple_ps_mask (zmm7, zmm_v0);
925+ mask0 = _mm512_cmp_ps_mask (zmm4, zmm_v0, 2 );
926+ mask1 = _mm512_cmp_ps_mask (zmm5, zmm_v0, 2 );
927+ mask2 = _mm512_cmp_ps_mask (zmm6, zmm_v0, 2 );
928+ mask3 = _mm512_cmp_ps_mask (zmm7, zmm_v0, 2 );
915929 xmm0 = _mm_mask_blend_epi8 (mask0, xmm0, _mm_loadu_si128 (reinterpret_cast <const __m128i*>(broadcast_f4_v + i * 16 )));
916930 xmm1 = _mm_mask_blend_epi8 (mask1, xmm1, _mm_loadu_si128 (reinterpret_cast <const __m128i*>(broadcast_f4_v + i * 16 )));
917931 xmm2 = _mm_mask_blend_epi8 (mask2, xmm2, _mm_loadu_si128 (reinterpret_cast <const __m128i*>(broadcast_f4_v + i * 16 )));
@@ -949,7 +963,7 @@ inline void f32_f4_quantize_1x16(const float* srcptr, int8_t* dstptr, int ld_src
949963 auto zp = _mm512_set1_ps (0 .8480964004993439f );
950964 zmm0 = _mm512_add_ps (zmm0, zp);
951965 } else {
952- mask1 = _mm512_cmplt_ps_mask (zmm0, zmm_v0);
966+ mask1 = _mm512_cmp_ps_mask (zmm0, zmm_v0, 1 );
953967 zmm0 = _mm512_abs_ps (zmm0);
954968 }
955969 constexpr int loop_num = F4_T == JBLAS_DTYPE::F4_NF4 ? 16 : 8 ;
@@ -959,7 +973,7 @@ inline void f32_f4_quantize_1x16(const float* srcptr, int8_t* dstptr, int ld_src
959973 if constexpr (F4_T == JBLAS_DTYPE::F4_BNB) sub_v = _mm512_set1_ps (F4_BNB_quant_sub_helper[i]);
960974 if constexpr (F4_T == JBLAS_DTYPE::F4_E2M1) sub_v = _mm512_set1_ps (F4_E2M1_quant_sub_helper[i]);
961975 zmm1 = _mm512_sub_ps (zmm0, sub_v);
962- mask0 = _mm512_cmple_ps_mask (zmm1, zmm_v0);
976+ mask0 = _mm512_cmp_ps_mask (zmm1, zmm_v0, 2 );
963977 xmm0 = _mm_mask_blend_epi8 (mask0, xmm0, _mm_loadu_si128 (reinterpret_cast <const __m128i*>(broadcast_f4_v + i * 16 )));
964978 zmm0 = _mm512_mask_add_ps (zmm0, mask0, zmm0, avoid_double_cmp);
965979 }
0 commit comments