@@ -42,17 +42,25 @@ infiniStatus_t Descriptor::create(infiniopHandle_t handle_, Descriptor **desc_pt
4242 return INFINI_STATUS_SUCCESS;
4343}
4444
/**
 * @brief Fake-quantize a single value: map it onto the integer grid, clamp,
 *        and map it back to the real domain.
 *
 * The integer code is computed as round(x / s + z), clamped to the closed
 * range [minq, maxq], and then de-quantized as s * (code - z).
 *
 * @param x    Value to quantize.
 * @param s    Quantization scale (step size). It is used as a divisor and is
 *             not checked against zero here.
 * @param z    Zero point, expressed in integer-code units.
 * @param minq Smallest representable integer code (0 for unsigned grids,
 *             negative for signed grids).
 * @param maxq Largest representable integer code.
 * @return The de-quantized value, i.e. x snapped to the nearest representable
 *         level inside [minq, maxq].
 */
float quantize(float x, float s, float z, float minq, float maxq) {
    // Project onto the integer grid (round half away from zero).
    float q = std::roundf(x / s + z);
    // Saturate to the representable code range. The lower bound is a
    // parameter rather than a hard-coded 0 so signed grids are supported.
    q = std::max(minq, std::min(maxq, q));
    // Map the clamped code back to the real domain.
    return s * (q - z);
}
5050
5151template <typename T>
5252void find_params (T *x, T *b_scale, T *zero, int N, int K,
5353 int bits = 4 , bool sym = false , bool mse = false ,
54- float norm = 2 .4f , int grid = 100 , float maxshrink = 0 .8f ) {
55- float maxq = static_cast <float >(std::pow (2 , bits) - 1 );
54+ float norm = 2 .4f , int grid = 100 , float maxshrink = 0 .8f , bool sign_ed = false ) {
55+ float maxq;
56+ float minq;
57+ if (sign_ed) { // 如果有符号量化
58+ maxq = static_cast <float >(std::pow (2 , bits - 1 ) - 1 );
59+ minq = -static_cast <float >(std::pow (2 , bits - 1 ));
60+ } else {
61+ maxq = static_cast <float >(std::pow (2 , bits) - 1 );
62+ minq = 0 .0f ;
63+ }
5664#pragma omp parallel for
5765 for (int n = 0 ; n < N; n++) {
5866 float x_min = FLT_MAX;
@@ -76,16 +84,16 @@ void find_params(T *x, T *b_scale, T *zero, int N, int K,
7684 x_max = 1 ;
7785 }
7886 if constexpr (std::is_same<T, fp16_t >::value) {
79- b_scale[n] = utils::cast<fp16_t >((x_max - x_min) / maxq);
87+ b_scale[n] = utils::cast<fp16_t >((x_max - x_min) / ( maxq - minq) );
8088 if (sym) {
81- zero[n] = utils::cast<fp16_t >((maxq + 1 .0f ) * 0 .5f );
89+ zero[n] = utils::cast<fp16_t >((maxq + minq + 1 .0f ) * 0 .5f );
8290 } else {
83- zero[n] = utils::cast<fp16_t >(-x_min * maxq / (x_max - x_min));
91+ zero[n] = utils::cast<fp16_t >(-x_min * ( maxq - minq) / (x_max - x_min));
8492 }
8593 } else if constexpr (std::is_same<T, float >::value) {
86- b_scale[n] = (x_max - x_min) / maxq;
94+ b_scale[n] = (x_max - x_min) / ( maxq - minq) ;
8795 if (sym) {
88- zero[n] = (maxq + 1 .0f ) * 0 .5f ;
96+ zero[n] = (maxq + minq + 1 .0f ) * 0 .5f ;
8997 } else {
9098 zero[n] = -x_min / b_scale[n];
9199 }
@@ -96,11 +104,11 @@ void find_params(T *x, T *b_scale, T *zero, int N, int K,
96104 float p = 1 - static_cast <float >(i) / static_cast <float >(grid);
97105 float x_min_1 = p * x_min;
98106 float x_max_1 = p * x_max;
99- float scale_1 = (x_max_1 - x_min_1) / maxq;
107+ float scale_1 = (x_max_1 - x_min_1) / ( maxq - minq) ;
100108 float zero_1 = (sym ? utils::cast<float >(zero[n]) : std::roundf (-x_min_1 / scale_1));
101109 float err = 0 .0f ;
102110 for (int k = 0 ; k < K; k++) {
103- float q = quantize (utils::cast<float >(x[n * K + k]), scale_1, zero_1, maxq);
111+ float q = quantize (utils::cast<float >(x[n * K + k]), scale_1, zero_1, minq, maxq);
104112 q -= utils::cast<float >(x[n * K + k]);
105113 q = std::abs (q);
106114 q = static_cast <float >(std::pow (q, norm));
@@ -344,12 +352,20 @@ void fasterquant(T *weight, T *Q, float *Err, T *b_scale, T *zero, float *Hess,
344352 int M, int K, int N,
345353 int block_size = 128 , float percdamp = 0.01 , int group_size = -1 ,
346354 int bits = 4 , bool sym = false , bool mse = false ,
347- float norm = 2.4 , int grid = 100 , float maxshrink = 0.8 ) {
348- float maxq = static_cast <float >(std::pow (2 , bits) - 1 );
355+ float norm = 2.4 , int grid = 100 , float maxshrink = 0.8 , bool sign_ed = false ) {
356+ float maxq;
357+ float minq;
358+ if (sign_ed) { // 如果有符号量化
359+ maxq = static_cast <float >(std::pow (2 , bits - 1 ) - 1 );
360+ minq = -static_cast <float >(std::pow (2 , bits - 1 ));
361+ } else {
362+ maxq = static_cast <float >(std::pow (2 , bits) - 1 );
363+ minq = 0 .0f ;
364+ }
349365 int num_groups = (group_size == -1 ? 1 : K / group_size);
350366
351367 if (group_size == -1 ) {
352- find_params (weight, b_scale, zero, N, K, bits, sym, mse, norm, grid, maxshrink);
368+ find_params (weight, b_scale, zero, N, K, bits, sym, mse, norm, grid, maxshrink, sign_ed );
353369 }
354370 float damp = 0 .0f ;
355371
@@ -388,13 +404,13 @@ void fasterquant(T *weight, T *Q, float *Err, T *b_scale, T *zero, float *Hess,
388404 if ((index * block_size + i) % group_size == 0 ) {
389405 int ind = (index * block_size + i) / group_size;
390406 for (int n = 0 ; n < N; n++) {
391- find_params (&weight[n * K + index * block_size + i], &b_scale[n * num_groups + ind], &zero[n * num_groups + ind], 1 , group_size, bits, sym, mse, norm, grid, maxshrink);
407+ find_params (&weight[n * K + index * block_size + i], &b_scale[n * num_groups + ind], &zero[n * num_groups + ind], 1 , group_size, bits, sym, mse, norm, grid, maxshrink, sign_ed );
392408 }
393409 }
394410 }
395411 int ind = (group_size != -1 ? (index * block_size + i) / group_size : 0 );
396412 for (int n = 0 ; n < N; n++) {
397- float q = quantize (utils::cast<float >(weight[n * K + index * block_size + i]), utils::cast<float >(b_scale[n * num_groups + ind]), utils::cast<float >(zero[n * num_groups + ind]), maxq);
413+ float q = quantize (utils::cast<float >(weight[n * K + index * block_size + i]), utils::cast<float >(b_scale[n * num_groups + ind]), utils::cast<float >(zero[n * num_groups + ind]), minq, maxq);
398414 if constexpr (std::is_same<T, fp16_t >::value) {
399415 Q[n * K + index * block_size + i] = utils::cast<fp16_t >(q);
400416 } else if constexpr (std::is_same<T, float >::value) {
@@ -435,8 +451,16 @@ void fasterquant(T *weight, T *Q, float *Err, T *b_scale, T *zero, float *Hess,
435451}
436452
437453void PackQuantizedWeight (fp16_t *Q, fp16_t *b_scale, fp16_t *zero,
438- int32_t *packed_weight, int K, int N, int group_size, int bits = 4 ) {
439- int maxq = int (std::pow (2 , bits) - 1 );
454+ int32_t *packed_weight, int K, int N, int group_size, int bits = 4 , bool sign_ed = false ) {
455+ int maxq;
456+ int minq;
457+ if (sign_ed) { // 如果有符号量化
458+ maxq = int (std::pow (2 , bits - 1 ) - 1 );
459+ minq = -int (std::pow (2 , bits - 1 ));
460+ } else {
461+ maxq = int (std::pow (2 , bits) - 1 );
462+ minq = 0 ;
463+ }
440464 int num_groups = (group_size == -1 ) ? 1 : K / group_size;
441465 int blocks_per_group = (group_size == -1 ) ? K / 8 : group_size / 8 ;
442466
@@ -458,7 +482,7 @@ void PackQuantizedWeight(fp16_t *Q, fp16_t *b_scale, fp16_t *zero,
458482 int k = row_base + i;
459483 float val = utils::cast<float >(Q[n * K + k]); // Q: [N, K]
460484 int q = static_cast <int >(std::roundf (val / scale + zero_f));
461- q = std::max (0 , std::min (maxq, q)); // clamp to [0 , maxq]
485+ q = std::max (minq , std::min (maxq, q)); // clamp to [minq , maxq]
462486 packed |= (q & 0xF ) << (i * 4 );
463487 }
464488
@@ -518,6 +542,7 @@ void quantWeights(void *workspace, int32_t *packed_weights,
518542 int grid = 100 ;
519543 float maxshrink = 0 .8f ;
520544 float nsamples = 0 .0f ;
545+ bool sign_ed = false ;
521546
522547 char *tmp = (char *)workspace + (K * K + N * block_size) * sizeof (float );
523548 float *Hess = (float *)workspace; // [K, K]
@@ -535,9 +560,9 @@ void quantWeights(void *workspace, int32_t *packed_weights,
535560 M, K, N,
536561 block_size, percdamp, group_size,
537562 bits, sym, mse,
538- norm, grid, maxshrink);
563+ norm, grid, maxshrink, sign_ed );
539564
540- PackQuantizedWeight (Q, b_scale, zero, packed_weights, K, N, group_size, bits);
565+ PackQuantizedWeight (Q, b_scale, zero, packed_weights, K, N, group_size, bits, sign_ed );
541566}
542567
543568void caculate (void *workspace, fp16_t *C, const fp16_t *A,
0 commit comments