Skip to content

Commit 7259593

Browse files
committed
issue/170: modified Hess
1 parent 7d89170 commit 7259593

File tree

2 files changed

+6
-6
lines changed

2 files changed

+6
-6
lines changed

src/infiniop/ops/matmul_gptq/cpu/matmul_gptq_cpu.cc

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -275,7 +275,8 @@ void fasterquant(T *weight, T *Q, T *Err, T *b_scale, T *zero, float *Hess,
275275
}
276276

277277
void PackQuantizedWeight(fp16_t *Q, fp16_t *b_scale, fp16_t *zero,
278-
int32_t *packed_weight, int K, int N, int group_size) {
278+
int32_t *packed_weight, int K, int N, int group_size, int bits = 4) {
279+
int maxq = int(std::pow(2, bits) - 1);
279280
int num_groups = (group_size == -1) ? 1 : K / group_size;
280281
int blocks_per_group = (group_size == -1) ? K / 8 : group_size / 8;
281282

@@ -297,7 +298,7 @@ void PackQuantizedWeight(fp16_t *Q, fp16_t *b_scale, fp16_t *zero,
297298
int k = row_base + i;
298299
float val = utils::cast<float>(Q[n * K + k]); // Q: [N, K]
299300
int q = static_cast<int>(std::roundf(val / scale + zero_f));
300-
q = std::max(0, std::min(15, q)); // clamp to [0, 15]
301+
q = std::max(0, std::min(maxq, q)); // clamp to [0, maxq]
301302
packed |= (q & 0xF) << (i * 4);
302303
}
303304

@@ -364,14 +365,15 @@ void quantWeights(void *workspace, int32_t *packed_weights,
364365
fp16_t *Q = (fp16_t *)tmp; //[N, K]
365366
fp16_t *weight = Q + N * K; //[N, K]
366367
fp16_t *Err = weight + N * K; //[N, blocksize=128]
368+
memset(Hess, 0, sizeof(float) * K * K);
367369
memcpy(weight, B, N * K * sizeof(fp16_t));
368370
add_batch<fp16_t>(A, Hess, nsamples, M, K);
369371
fasterquant<fp16_t>(weight, Q, Err, b_scale, zero, Hess,
370372
M, K, N,
371373
blocksize, percdamp, group_size,
372374
bits, sym, mse,
373375
norm, grid, maxshrink);
374-
PackQuantizedWeight(Q, b_scale, zero, packed_weights, K, N, group_size);
376+
PackQuantizedWeight(Q, b_scale, zero, packed_weights, K, N, group_size, bits);
375377
}
376378

377379
void caculate(void *workspace, fp16_t *C, const fp16_t *A,

xmake/cpu.lua

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,4 @@
1-
if not is_plat("windows") then
2-
add_requires("lapack", {configs = {shared = true}})
3-
end
1+
add_requires("lapack", {configs = {shared = true}})
42
target("infiniop-cpu")
53
set_kind("static")
64
add_deps("infini-utils")

0 commit comments

Comments
 (0)