Commit de70b25

[safetensors] Fix GPTQ/AWQ quantized model parameter counting (#1770)
### Fix GPTQ/AWQ quantized model parameter counting

Fixes the parameter count calculation for GPTQ/AWQ quantized models by:

- Applying an 8x multiplier to `qweight` tensors based on quantization bits (32 / 4 = 8 for 4-bit)
- Skipping auxiliary tensors (`qzeros`, `g_idx`, `scales`) in the parameter count
- Defaulting to quantized tensor detection when no exclusion list is provided

Before: [RedHatAI/Meta-Llama-3.1-8B-Instruct-quantized.w4a16](https://huggingface.co/RedHatAI/Meta-Llama-3.1-8B-Instruct-quantized.w4a16) reported ~2B parameters
After: correctly reports ~8B parameters
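For intuition, here is the packing arithmetic behind the 8x multiplier, using the counts asserted in the new test below (a standalone sketch, not the library code):

```ts
// 4-bit GPTQ/AWQ packs 32 / 4 = 8 weights into each 32-bit qweight element.
const bits = 4;
const multiplier = Math.floor(32 / bits); // 8

// Counts from the new spec for the w4a16 checkpoint:
const packedQweightElements = 872_415_232; // raw I32 elements on disk (6_979_321_856 / 8)
const quantizedParams = packedQweightElements * multiplier; // 6_979_321_856
const f16Params = 1_052_315_648; // tensors kept in half precision (e.g. embeddings, norms)

console.log(quantizedParams + f16Params); // 8_031_637_504 ≈ 8B
```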
1 parent 3c573a5 commit de70b25

File tree

2 files changed (+73, -2 lines):

- packages/hub/src/lib/parse-safetensors-metadata.spec.ts
- packages/hub/src/lib/parse-safetensors-metadata.ts

packages/hub/src/lib/parse-safetensors-metadata.spec.ts

Lines changed: 25 additions & 0 deletions
```diff
@@ -207,6 +207,31 @@ describe("parseSafetensorsMetadata", () => {
 		assert.strictEqual(parameterCount.E8M0, 24);
 	});
 
+	it("fetch info for GPTQ quantized 8B model", async () => {
+		const parse = await parseSafetensorsMetadata({
+			repo: "RedHatAI/Meta-Llama-3.1-8B-Instruct-quantized.w4a16",
+			revision: "3921b6aee65496a708b0af456c964ceca7423193",
+			computeParametersCount: true,
+		});
+
+		const parameterCount = parse.parameterCount;
+		assert.ok(parameterCount);
+		assert.ok(parameterCount.I32);
+		assert.ok(parameterCount.F16);
+		assert.strictEqual(parameterCount.I32, 6_979_321_856);
+		assert.strictEqual(parameterCount.F16, 1_052_315_648);
+
+		const parameterCountTotal =
+			parse.parameterTotal ??
+			sum(
+				Object.entries(parameterCount)
+					.filter(([, value]) => typeof value === "number")
+					.map(([, value]) => value as number)
+			);
+
+		assert.strictEqual(parameterCountTotal, 8_031_637_504);
+	});
+
 	it("fetch info for openai/gpt-oss-20b (large sharded model)", async () => {
 		const parse = await parseSafetensorsMetadata({
 			repo: "openai/gpt-oss-20b",
```

packages/hub/src/lib/parse-safetensors-metadata.ts

Lines changed: 48 additions & 2 deletions
```diff
@@ -36,6 +36,7 @@ export function parseSafetensorsShardFilename(filename: string): SafetensorsShar
 
 const PARALLEL_DOWNLOADS = 20;
 const MAX_HEADER_LENGTH = 25_000_000;
+const GPTQ_QWEIGHT_SUFFIX = "qweight";
 
 class SafetensorParseError extends Error {}
 
@@ -362,10 +363,14 @@ export interface ModelConfig {
  * Determines if a tensor is quantized based on quantization config and tensor name
  */
 function isQuantizedTensor(tensorName: string, quantConfig?: QuantizationConfig): boolean {
-	if (!quantConfig || !quantConfig.modules_to_not_convert) {
+	if (!quantConfig) {
 		return false;
 	}
 
+	if (!quantConfig.modules_to_not_convert || quantConfig.modules_to_not_convert.length === 0) {
+		return true;
+	}
+
 	for (const pattern of quantConfig.modules_to_not_convert) {
 		const regexPattern = pattern.replace(/\*/g, ".*");
 		const regex = new RegExp(regexPattern);
@@ -385,7 +390,9 @@ function getQuantizationMultiplier(tensorName: string, dtype: Dtype, quantConfig
 		return 1;
 	}
 
-	switch (quantConfig.quant_method) {
+	const quantMethod = quantConfig.quant_method?.toLowerCase();
+
+	switch (quantMethod) {
 		case "mxfp4":
 			if (dtype === "U8" && tensorName.includes("_blocks")) {
 				return 2;
@@ -394,6 +401,10 @@ function getQuantizationMultiplier(tensorName: string, dtype: Dtype, quantConfig
 
 		case "gptq":
 		case "awq":
+			if (getTensorSuffix(tensorName) === GPTQ_QWEIGHT_SUFFIX) {
+				const bits = quantConfig.bits && quantConfig.bits > 0 ? quantConfig.bits : 4;
+				return Math.max(1, Math.floor(32 / bits));
+			}
 			if (quantConfig.bits === 4 && dtype === "U8") {
 				return 2;
 			}
@@ -430,12 +441,18 @@ function computeNumOfParamsByDtypeSingleFile(
 	const tensors = omit(header, "__metadata__");
 
 	for (const [tensorName, v] of typedEntries(tensors)) {
+		if (shouldSkipTensor(tensorName, quantConfig)) {
+			continue;
+		}
 		if (v.shape.length === 0) {
 			continue;
 		}
 
 		const elements = v.shape.reduce((a, b) => a * b);
 		const multiplier = quantConfig ? getQuantizationMultiplier(tensorName, v.dtype, quantConfig) : 1;
+		if (multiplier === 0) {
+			continue;
+		}
 		counter[v.dtype] = (counter[v.dtype] ?? 0) + elements * multiplier;
 	}
 	return counter;
@@ -453,3 +470,32 @@ function computeNumOfParamsByDtypeSharded(
 	}
 	return counter;
 }
+
+function getTensorSuffix(tensorName: string): string {
+	const lastDotIndex = tensorName.lastIndexOf(".");
+	return lastDotIndex === -1 ? tensorName : tensorName.slice(lastDotIndex + 1);
+}
+
+function shouldSkipTensor(tensorName: string, quantConfig?: QuantizationConfig): boolean {
+	const GPTQ_AWQ_AUXILIARY_SUFFIXES = ["qzeros", "g_idx", "scales"];
+
+	if (!quantConfig) {
+		return false;
+	}
+
+	const quantMethod = quantConfig.quant_method?.toLowerCase();
+	if (!quantMethod || (quantMethod !== "gptq" && quantMethod !== "awq")) {
+		return false;
+	}
+
+	if (!isQuantizedTensor(tensorName, quantConfig)) {
+		return false;
+	}
+
+	const suffix = getTensorSuffix(tensorName);
+	if (suffix === GPTQ_QWEIGHT_SUFFIX) {
+		return false;
+	}
+
+	return GPTQ_AWQ_AUXILIARY_SUFFIXES.includes(suffix);
+}
```
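To see the skip logic and the qweight multiplier working together, here is a small standalone sketch with an illustrative 4-bit GPTQ layer; the tensor names and shapes are made up for the example and are not taken from the checkpoint:

```ts
// Standalone sketch mirroring the diff's counting logic; names/shapes illustrative.
const bits = 4;
const multiplier = Math.max(1, Math.floor(32 / bits)); // 8 packed weights per int32

// A hypothetical 4-bit GPTQ linear layer as stored in safetensors:
const tensors: Record<string, { dtype: string; shape: number[] }> = {
	"model.layers.0.self_attn.q_proj.qweight": { dtype: "I32", shape: [512, 4096] }, // packed weights → counted ×8
	"model.layers.0.self_attn.q_proj.qzeros": { dtype: "I32", shape: [32, 512] }, // auxiliary → skipped
	"model.layers.0.self_attn.q_proj.scales": { dtype: "F16", shape: [32, 4096] }, // auxiliary → skipped
	"model.layers.0.self_attn.q_proj.g_idx": { dtype: "I32", shape: [4096] }, // auxiliary → skipped
};

const auxiliary = ["qzeros", "g_idx", "scales"];
let count = 0;
for (const [name, t] of Object.entries(tensors)) {
	const suffix = name.slice(name.lastIndexOf(".") + 1);
	if (auxiliary.includes(suffix)) continue; // shouldSkipTensor equivalent
	const elements = t.shape.reduce((a, b) => a * b, 1);
	count += suffix === "qweight" ? elements * multiplier : elements;
}
console.log(count); // 512 * 4096 * 8 = 16_777_216 unpacked weights
```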
