@@ -275,7 +275,8 @@ void fasterquant(T *weight, T *Q, T *Err, T *b_scale, T *zero, float *Hess,
275275}
276276
277277void PackQuantizedWeight (fp16_t *Q, fp16_t *b_scale, fp16_t *zero,
278- int32_t *packed_weight, int K, int N, int group_size) {
278+ int32_t *packed_weight, int K, int N, int group_size, int bits = 4 ) {
279+ int maxq = int (std::pow (2 , bits) - 1 );
279280 int num_groups = (group_size == -1 ) ? 1 : K / group_size;
280281 int blocks_per_group = (group_size == -1 ) ? K / 8 : group_size / 8 ;
281282
@@ -297,7 +298,7 @@ void PackQuantizedWeight(fp16_t *Q, fp16_t *b_scale, fp16_t *zero,
297298 int k = row_base + i;
298299 float val = utils::cast<float >(Q[n * K + k]); // Q: [N, K]
299300 int q = static_cast <int >(std::roundf (val / scale + zero_f));
300- q = std::max (0 , std::min (15 , q)); // clamp to [0, 15 ]
301+ q = std::max (0 , std::min (maxq , q)); // clamp to [0, maxq ]
301302 packed |= (q & 0xF ) << (i * 4 );
302303 }
303304
@@ -364,14 +365,15 @@ void quantWeights(void *workspace, int32_t *packed_weights,
364365 fp16_t *Q = (fp16_t *)tmp; // [N, K]
365366 fp16_t *weight = Q + N * K; // [N, K]
366367 fp16_t *Err = weight + N * K; // [N, blocksize=128]
368+ memset (Hess, 0 , sizeof (float ) * K * K);
367369 memcpy (weight, B, N * K * sizeof (fp16_t ));
368370 add_batch<fp16_t >(A, Hess, nsamples, M, K);
369371 fasterquant<fp16_t >(weight, Q, Err, b_scale, zero, Hess,
370372 M, K, N,
371373 blocksize, percdamp, group_size,
372374 bits, sym, mse,
373375 norm, grid, maxshrink);
374- PackQuantizedWeight (Q, b_scale, zero, packed_weights, K, N, group_size);
376+ PackQuantizedWeight (Q, b_scale, zero, packed_weights, K, N, group_size, bits );
375377}
376378
377379void caculate (void *workspace, fp16_t *C, const fp16_t *A,
0 commit comments