Skip to content
This repository was archived by the owner on Oct 25, 2024. It is now read-only.

Commit 47d8755

Browse files
authored
update jblas' bug fix (#1057)
1 parent 1ed3862 commit 47d8755

File tree

4 files changed: 97 additions and 46 deletions

intel_extension_for_transformers/llm/library/jblas/jblas/jit_blas_parallel.h

Lines changed: 9 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -204,7 +204,7 @@ class SchedulerBase : public Scheduler2D {
204204
mL2Use += static_cast<size_t>(mBlock[1]) * mBlock[2] * mEleSize[1];
205205
mL2Use += static_cast<size_t>(mStep[0]) * mBlock[2] * mEleSize[0];
206206
}
207-
const float DensityThres = 32;
207+
const float DensityThres = 16;
208208
static size_t constexpr ReservedSize = 32ULL * 1024ULL;
209209

210210
virtual float calculate_score() {
@@ -364,7 +364,7 @@ class SchedulerKBlock : public Scheduler2D {
364364
mL2Use += static_cast<size_t>(mBlock[1]) * mBlock[2] * mEleSize[1];
365365
mL2Use += static_cast<size_t>(mStep[0]) * mBlock[2] * mEleSize[0];
366366
}
367-
const float DensityThres = 32;
367+
const float DensityThres = 16;
368368

369369
float calculate_score() {
370370
int tmpnstep = mThdSize[1] < _GemmCore_T::PREFERRED_N ? mThdSize[1] : _GemmCore_T::PREFERRED_N;
@@ -492,10 +492,11 @@ class SchedulerKBlockS : public SchedulerBase<_GemmCore_T> {
492492
assert(this->mBlock[0]>0);
493493
assert(this->mBlock[1]>0);
494494
assert(this->mBlock[2]>0);
495+
assert(this->mBlock[2] % _GemmCore_T::KTILE == 0);
495496
}
496497

497498
protected:
498-
const float DensityThres = 32;
499+
const float DensityThres = 16;
499500
static size_t constexpr ReservedSize = 32ULL * 1024ULL;
500501

501502
void cache_blocking_compute() override {
@@ -529,6 +530,11 @@ class SchedulerKBlockS : public SchedulerBase<_GemmCore_T> {
529530
(this->mStep[0] * this->mEleSize[0] +
530531
float(CorSize * (this->mStep[0] + this->mBlock[1])) / this->mKBlock +
531532
this->mBlock[1] * this->mEleSize[1]));
533+
if (rawk < this->mKBlock) {
534+
rawk = static_cast<int>((valid_total - this->mBlock[0] * this->mBlock[1] * this->mEleSize[2] -
535+
1 * CorSize * (this->mStep[0] + this->mBlock[1])) /
536+
(this->mStep[0] * this->mEleSize[0] + this->mBlock[1] * this->mEleSize[1]));
537+
}
532538
rawk = std::min(rawk, this->mSizePadded[2]);
533539
this->mBlock[2] = utils::padto_le(rawk, this->mStep[2]);
534540
if (this->mBlock[2] > this->mKBlock) {
@@ -569,9 +575,6 @@ class SchedulerKBlockS : public SchedulerBase<_GemmCore_T> {
569575
this->mBlock[2] = static_cast<int>(getMaxK(this->mBlock[1]));
570576
this->mBlock[2] = utils::padto_le(this->mBlock[2], this->mStep[2]);
571577
this->mBlock[2] = std::min(mKBlock, this->mBlock[2]);
572-
auto tmp = utils::updiv(mKBlock, this->mBlock[2]);
573-
while (mKBlock % tmp != 0) tmp++; // TODO(Yu) optimize
574-
this->mBlock[2] = utils::downdiv(mKBlock, tmp);
575578
}
576579
}
577580

intel_extension_for_transformers/llm/library/jblas/jblas/kernel_avx2.h

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -412,10 +412,14 @@ inline JBLAS_CODE decompress_kblock_f8_fp(utils::f8* srcptr, _DST_T* dstptr, int
412412
for (; j < align_col; j += 8) quant();
413413
for (; j < col; j++) {
414414
auto fp_v = ref::f8_to_fp32(srcptr[i * ld_src + j], src_f8_type);
415-
if constexpr (std::is_same_v<_S_T, utils::f8>) {
416-
dstptr[i * ld_dst + j] = fp_v * std::pow(2, sptr[j / _PACK_ROW].x);
417-
} else if constexpr (std::is_same_v<_S_T, float>) {
418-
dstptr[i * ld_dst + j] = fp_v * sptr[j / _PACK_ROW];
415+
if constexpr (WITH_SCALE) {
416+
if constexpr (std::is_same_v<_S_T, utils::f8>) {
417+
dstptr[i * ld_dst + j] = fp_v * std::pow(2, sptr[j / _PACK_ROW].x);
418+
} else if constexpr (std::is_same_v<_S_T, float>) {
419+
dstptr[i * ld_dst + j] = fp_v * sptr[j / _PACK_ROW];
420+
}
421+
} else {
422+
dstptr[i * ld_dst + j] = fp_v;
419423
}
420424
}
421425
}
@@ -636,6 +640,14 @@ static inline JBLAS_CODE decompress_kblock_bit4_packrow1(
636640
vzps[iv] = _mm256_cvtepi8_epi32(tmp);
637641
}
638642
}
643+
auto rowre = row - irow;
644+
int rowpad4 = utils::padto_le(rowre, UnrollRow) + irow;
645+
for (; irow < rowpad4; irow += UnrollRow) {
646+
for (int iter16 = 0; iter16 < Loop16; iter16++)
647+
pad_bit4_16(tmpbuf + iter16 * 16, reinterpret_cast<int8_t*>(srcptr + irow * ld_src / 2 + 8 * iter16));
648+
for (int iterr = 0; iterr < UnrollRow; iterr++)
649+
dequantize(dstptr + (irow + iterr) * ld_dst, tmpbuf + iterr * _NCOL, vscales, vzps);
650+
}
639651
for (; irow < row; irow++) {
640652
if constexpr (_NCOL == 24) {
641653
pad_bit4_16(tmpbuf, reinterpret_cast<int8_t*>(srcptr + irow * ld_src / 2));

intel_extension_for_transformers/llm/library/jblas/jblas/kernel_avx512f.h

Lines changed: 33 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -321,13 +321,28 @@ static inline JBLAS_CODE decompress_kblock_bit4_packrow1(utils::bit4x2* srcptr,
321321
vzps[iv] = _mm512_cvtepi8_epi32(tmp);
322322
}
323323
}
324-
}
325-
for (; irow < row; irow++) {
326-
pad_bit4(tmpbuf, reinterpret_cast<int8_t*>(srcptr + irow * ld_src / 2), zmm_mask, LoadMask48);
327-
if constexpr (_IS_SYM) {
328-
dequantize(dstptr + irow * ld_dst, tmpbuf, vscales, nullptr);
329-
} else {
330-
dequantize(dstptr + irow * ld_dst, tmpbuf, vscales, vzps);
324+
auto rowre = row - irow;
325+
int rowpad4 = utils::padto_le(rowre, UnrollRow) + irow;
326+
for (; irow < rowpad4; irow += UnrollRow) {
327+
for (int iter64 = 0; iter64 < Loop64; iter64++) {
328+
pad_bit4(tmpbuf + iter64 * 64, reinterpret_cast<int8_t*>(srcptr + irow * ld_src / 2 + 32 * iter64), zmm_mask,
329+
LoadMask64);
330+
}
331+
for (int iterr = 0; iterr < UnrollRow; iterr++) {
332+
if constexpr (_IS_SYM) {
333+
dequantize(dstptr + (irow + iterr) * ld_dst, tmpbuf + iterr * ColTile, vscales, nullptr);
334+
} else {
335+
dequantize(dstptr + (irow + iterr) * ld_dst, tmpbuf + iterr * ColTile, vscales, vzps);
336+
}
337+
}
338+
}
339+
for (; irow < row; irow++) {
340+
pad_bit4(tmpbuf, reinterpret_cast<int8_t*>(srcptr + irow * ld_src / 2), zmm_mask, LoadMask48);
341+
if constexpr (_IS_SYM) {
342+
dequantize(dstptr + irow * ld_dst, tmpbuf, vscales, nullptr);
343+
} else {
344+
dequantize(dstptr + irow * ld_dst, tmpbuf, vscales, vzps);
345+
}
331346
}
332347
}
333348
return JblasSuccess;
@@ -563,9 +578,8 @@ inline JBLAS_CODE decompress_kblock_f8_fp(utils::f8* srcptr, _DST_T* dstptr, int
563578
auto sptr = scales + kpos * NPad;
564579
int j = 0;
565580
auto quant = [&](__mmask16 mask) {
566-
__m128i f8_src;
567581
auto sign_revert =
568-
_mm512_cvtepi8_epi32(_mm_mask_loadu_epi8(f8_src, mask, reinterpret_cast<__m128i*>(srcptr + i * ld_src + j)));
582+
_mm512_cvtepi8_epi32(_mm_maskz_loadu_epi8(mask, reinterpret_cast<__m128i*>(srcptr + i * ld_src + j)));
569583
auto e_revert = sign_revert;
570584
auto mantissa_revert = sign_revert;
571585
sign_revert = _mm512_slli_epi32(sign_revert, 24);
@@ -888,10 +902,10 @@ inline void f32_f4_quantize_4x16(const float* srcptr, int8_t* dstptr, int ld_src
888902
zmm2 = _mm512_add_ps(zmm2, zmm_zp);
889903
zmm3 = _mm512_add_ps(zmm3, zmm_zp);
890904
} else {
891-
mask4 = _mm512_cmplt_ps_mask(zmm0, zmm_v0);
892-
mask5 = _mm512_cmplt_ps_mask(zmm1, zmm_v0);
893-
mask6 = _mm512_cmplt_ps_mask(zmm2, zmm_v0);
894-
mask7 = _mm512_cmplt_ps_mask(zmm3, zmm_v0);
905+
mask4 = _mm512_cmp_ps_mask(zmm0, zmm_v0, 1);
906+
mask5 = _mm512_cmp_ps_mask(zmm1, zmm_v0, 1);
907+
mask6 = _mm512_cmp_ps_mask(zmm2, zmm_v0, 1);
908+
mask7 = _mm512_cmp_ps_mask(zmm3, zmm_v0, 1);
895909

896910
zmm0 = _mm512_abs_ps(zmm0);
897911
zmm1 = _mm512_abs_ps(zmm1);
@@ -908,10 +922,10 @@ inline void f32_f4_quantize_4x16(const float* srcptr, int8_t* dstptr, int ld_src
908922
zmm5 = _mm512_sub_ps(zmm1, sub_v);
909923
zmm6 = _mm512_sub_ps(zmm2, sub_v);
910924
zmm7 = _mm512_sub_ps(zmm3, sub_v);
911-
mask0 = _mm512_cmple_ps_mask(zmm4, zmm_v0);
912-
mask1 = _mm512_cmple_ps_mask(zmm5, zmm_v0);
913-
mask2 = _mm512_cmple_ps_mask(zmm6, zmm_v0);
914-
mask3 = _mm512_cmple_ps_mask(zmm7, zmm_v0);
925+
mask0 = _mm512_cmp_ps_mask(zmm4, zmm_v0, 2);
926+
mask1 = _mm512_cmp_ps_mask(zmm5, zmm_v0, 2);
927+
mask2 = _mm512_cmp_ps_mask(zmm6, zmm_v0, 2);
928+
mask3 = _mm512_cmp_ps_mask(zmm7, zmm_v0, 2);
915929
xmm0 = _mm_mask_blend_epi8(mask0, xmm0, _mm_loadu_si128(reinterpret_cast<const __m128i*>(broadcast_f4_v + i * 16)));
916930
xmm1 = _mm_mask_blend_epi8(mask1, xmm1, _mm_loadu_si128(reinterpret_cast<const __m128i*>(broadcast_f4_v + i * 16)));
917931
xmm2 = _mm_mask_blend_epi8(mask2, xmm2, _mm_loadu_si128(reinterpret_cast<const __m128i*>(broadcast_f4_v + i * 16)));
@@ -949,7 +963,7 @@ inline void f32_f4_quantize_1x16(const float* srcptr, int8_t* dstptr, int ld_src
949963
auto zp = _mm512_set1_ps(0.8480964004993439f);
950964
zmm0 = _mm512_add_ps(zmm0, zp);
951965
} else {
952-
mask1 = _mm512_cmplt_ps_mask(zmm0, zmm_v0);
966+
mask1 = _mm512_cmp_ps_mask(zmm0, zmm_v0, 1);
953967
zmm0 = _mm512_abs_ps(zmm0);
954968
}
955969
constexpr int loop_num = F4_T == JBLAS_DTYPE::F4_NF4 ? 16 : 8;
@@ -959,7 +973,7 @@ inline void f32_f4_quantize_1x16(const float* srcptr, int8_t* dstptr, int ld_src
959973
if constexpr (F4_T == JBLAS_DTYPE::F4_BNB) sub_v = _mm512_set1_ps(F4_BNB_quant_sub_helper[i]);
960974
if constexpr (F4_T == JBLAS_DTYPE::F4_E2M1) sub_v = _mm512_set1_ps(F4_E2M1_quant_sub_helper[i]);
961975
zmm1 = _mm512_sub_ps(zmm0, sub_v);
962-
mask0 = _mm512_cmple_ps_mask(zmm1, zmm_v0);
976+
mask0 = _mm512_cmp_ps_mask(zmm1, zmm_v0, 2);
963977
xmm0 = _mm_mask_blend_epi8(mask0, xmm0, _mm_loadu_si128(reinterpret_cast<const __m128i*>(broadcast_f4_v + i * 16)));
964978
zmm0 = _mm512_mask_add_ps(zmm0, mask0, zmm0, avoid_double_cmp);
965979
}

intel_extension_for_transformers/llm/library/jblas/jblas/kernel_ref.h

Lines changed: 39 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -230,25 +230,47 @@ inline void convert_s4_s8_8(int8_t* dstptr, int8_t* srcptr) {
230230
dstptr[7] = tmp;
231231
}
232232

233+
inline void convert_s4_s8_8_lowbits(int8_t* dstptr, int8_t* srcptr) {
234+
auto src32 = *reinterpret_cast<uint32_t*>(srcptr);
235+
auto tmp = static_cast<int>(src32 & 0xf);
236+
dstptr[0] = static_cast<int8_t>(tmp);
237+
tmp = static_cast<int>(src32 & 0xf0) >> 4;
238+
dstptr[1] = static_cast<int8_t>(tmp);
239+
tmp = static_cast<int>((src32 & 0xf00) >> 8);
240+
dstptr[2] = static_cast<int8_t>(tmp);
241+
tmp = static_cast<int>((src32 & 0xf000) >> 12);
242+
dstptr[3] = static_cast<int8_t>(tmp);
243+
tmp = static_cast<int>((src32 & 0xf0000) >> 16);
244+
dstptr[4] = static_cast<int8_t>(tmp);
245+
tmp = static_cast<int>((src32 & 0xf00000) >> 20);
246+
dstptr[5] = static_cast<int8_t>(tmp);
247+
tmp = static_cast<int>((src32 & 0xf000000) >> 24);
248+
dstptr[6] = static_cast<int8_t>(tmp);
249+
tmp = static_cast<int>((src32 & 0xf0000000) >> 28);
250+
dstptr[7] = static_cast<int8_t>(tmp);
251+
}
252+
233253
template <>
234254
inline void convert_s4_s8_8<JBLAS_DTYPE::S4_FULLRANGE>(int8_t* dstptr, int8_t* srcptr) {
235-
auto src32 = *reinterpret_cast<uint32_t*>(srcptr);
236-
auto tmp = static_cast<int8_t>(src32 & 0xf);
237-
dstptr[0] = tmp - 8;
238-
tmp = static_cast<int8_t>(src32 & 0xf0) >> 4;
239-
dstptr[1] = tmp - 8;
240-
tmp = static_cast<int8_t>((src32 & 0xf00) >> 8);
241-
dstptr[2] = tmp - 8;
242-
tmp = static_cast<int8_t>((src32 & 0xf000) >> 12);
243-
dstptr[3] = tmp - 8;
244-
tmp = static_cast<int8_t>((src32 & 0xf0000) >> 16);
245-
dstptr[4] = tmp - 8;
246-
tmp = static_cast<int8_t>((src32 & 0xf00000) >> 20);
247-
dstptr[5] = tmp - 8;
248-
tmp = static_cast<int8_t>((src32 & 0xf000000) >> 24);
249-
dstptr[6] = tmp - 8;
250-
tmp = static_cast<int8_t>((src32 & 0xf0000000) >> 28);
251-
dstptr[7] = tmp - 8;
255+
convert_s4_s8_8_lowbits(dstptr, srcptr);
256+
for (size_t i = 0; i < 8; i++) {
257+
dstptr[i] -= 8;
258+
}
259+
}
260+
261+
template <>
262+
inline void convert_s4_s8_8<JBLAS_DTYPE::F4_BNB>(int8_t* dstptr, int8_t* srcptr) {
263+
convert_s4_s8_8_lowbits(dstptr, srcptr);
264+
}
265+
266+
template <>
267+
inline void convert_s4_s8_8<JBLAS_DTYPE::F4_NF4>(int8_t* dstptr, int8_t* srcptr) {
268+
convert_s4_s8_8_lowbits(dstptr, srcptr);
269+
}
270+
271+
template <>
272+
inline void convert_s4_s8_8<JBLAS_DTYPE::F4_E2M1>(int8_t* dstptr, int8_t* srcptr) {
273+
convert_s4_s8_8_lowbits(dstptr, srcptr);
252274
}
253275

254276
template <JBLAS_DTYPE S4_T>

0 commit comments

Comments
 (0)