@@ -36,6 +36,7 @@ export function parseSafetensorsShardFilename(filename: string): SafetensorsShar
 
 const PARALLEL_DOWNLOADS = 20;
 const MAX_HEADER_LENGTH = 25_000_000;
+const GPTQ_QWEIGHT_SUFFIX = "qweight";
 
 class SafetensorParseError extends Error {}
 
@@ -362,10 +363,14 @@ export interface ModelConfig {
  * Determines if a tensor is quantized based on quantization config and tensor name
  */
 function isQuantizedTensor(tensorName: string, quantConfig?: QuantizationConfig): boolean {
-	if (!quantConfig || !quantConfig.modules_to_not_convert) {
+	if (!quantConfig) {
 		return false;
 	}
 
+	if (!quantConfig.modules_to_not_convert || quantConfig.modules_to_not_convert.length === 0) {
+		return true;
+	}
+
 	for (const pattern of quantConfig.modules_to_not_convert) {
 		const regexPattern = pattern.replace(/\*/g, ".*");
 		const regex = new RegExp(regexPattern);
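As a quick sanity check of the new early return, here is a minimal sketch of how isQuantizedTensor would classify a few hypothetical tensor names, assuming the unchanged tail of the loop returns false when a pattern matches and true otherwise:

```ts
// Hypothetical QuantizationConfig; the tensor names below are made up.
const cfg = { quant_method: "gptq", bits: 4, modules_to_not_convert: ["lm_head", "model.layers.*.norm"] };

isQuantizedTensor("lm_head.weight", cfg); // false ("lm_head" matches)
isQuantizedTensor("model.layers.7.norm.weight", cfg); // false (wildcard pattern matches)
isQuantizedTensor("model.layers.7.mlp.down_proj.qweight", cfg); // true (no exclusion matches)

// New behavior: a missing or empty modules_to_not_convert now means every tensor is quantized.
isQuantizedTensor("model.embed_tokens.weight", { quant_method: "gptq", bits: 4 }); // true
```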
@@ -385,7 +390,9 @@ function getQuantizationMultiplier(tensorName: string, dtype: Dtype, quantConfig
 		return 1;
 	}
 
-	switch (quantConfig.quant_method) {
+	const quantMethod = quantConfig.quant_method?.toLowerCase();
+
+	switch (quantMethod) {
 		case "mxfp4":
 			if (dtype === "U8" && tensorName.includes("_blocks")) {
 				return 2;
@@ -394,6 +401,10 @@ function getQuantizationMultiplier(tensorName: string, dtype: Dtype, quantConfig
 
 		case "gptq":
 		case "awq":
+			if (getTensorSuffix(tensorName) === GPTQ_QWEIGHT_SUFFIX) {
+				const bits = quantConfig.bits && quantConfig.bits > 0 ? quantConfig.bits : 4;
+				return Math.max(1, Math.floor(32 / bits));
+			}
 			if (quantConfig.bits === 4 && dtype === "U8") {
 				return 2;
 			}
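The new branch treats a qweight tensor as packed 32-bit storage, so each stored element represents 32 / bits weights. A rough sketch of the resulting multipliers, assuming the tensor is not excluded by modules_to_not_convert (so the earlier `return 1` path is not taken); the tensor names and configs are illustrative:

```ts
const gptq4bit = { quant_method: "gptq", bits: 4 };
const gptq8bit = { quant_method: "gptq", bits: 8 };

getQuantizationMultiplier("model.layers.0.mlp.up_proj.qweight", "I32", gptq4bit); // 8 (32 / 4)
getQuantizationMultiplier("model.layers.0.mlp.up_proj.qweight", "I32", gptq8bit); // 4 (32 / 8)

// Missing or non-positive bits falls back to 4-bit packing, i.e. a multiplier of 8.
getQuantizationMultiplier("model.layers.0.mlp.up_proj.qweight", "I32", { quant_method: "awq" }); // 8
```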
@@ -430,12 +441,18 @@ function computeNumOfParamsByDtypeSingleFile(
 	const tensors = omit(header, "__metadata__");
 
 	for (const [tensorName, v] of typedEntries(tensors)) {
+		if (shouldSkipTensor(tensorName, quantConfig)) {
+			continue;
+		}
 		if (v.shape.length === 0) {
 			continue;
 		}
 
 		const elements = v.shape.reduce((a, b) => a * b);
 		const multiplier = quantConfig ? getQuantizationMultiplier(tensorName, v.dtype, quantConfig) : 1;
+		if (multiplier === 0) {
+			continue;
+		}
 		counter[v.dtype] = (counter[v.dtype] ?? 0) + elements * multiplier;
 	}
 	return counter;
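For intuition, a worked example of what the loop now counts for a single hypothetical 4-bit GPTQ projection (name, shape, and dtype are illustrative):

```ts
// "model.layers.0.self_attn.q_proj.qweight": shape [512, 4096], dtype "I32"
//   elements   = 512 * 4096    = 2_097_152 stored int32 values
//   multiplier = floor(32 / 4) = 8
//   counter["I32"] += 2_097_152 * 8, i.e. 16_777_216 counted weights
//
// The sibling q_proj.qzeros, q_proj.scales, and q_proj.g_idx tensors are auxiliary
// quantization state; shouldSkipTensor (added below) drops them so they no longer
// inflate the total.
```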
@@ -453,3 +470,32 @@ function computeNumOfParamsByDtypeSharded(
 	}
 	return counter;
 }
+
+function getTensorSuffix(tensorName: string): string {
+	const lastDotIndex = tensorName.lastIndexOf(".");
+	return lastDotIndex === -1 ? tensorName : tensorName.slice(lastDotIndex + 1);
+}
+
+function shouldSkipTensor(tensorName: string, quantConfig?: QuantizationConfig): boolean {
+	const GPTQ_AWQ_AUXILIARY_SUFFIXES = ["qzeros", "g_idx", "scales"];
+
+	if (!quantConfig) {
+		return false;
+	}
+
+	const quantMethod = quantConfig.quant_method?.toLowerCase();
+	if (!quantMethod || (quantMethod !== "gptq" && quantMethod !== "awq")) {
+		return false;
+	}
+
+	if (!isQuantizedTensor(tensorName, quantConfig)) {
+		return false;
+	}
+
+	const suffix = getTensorSuffix(tensorName);
+	if (suffix === GPTQ_QWEIGHT_SUFFIX) {
+		return false;
+	}
+
+	return GPTQ_AWQ_AUXILIARY_SUFFIXES.includes(suffix);
+}
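A short usage sketch for the two new helpers. The config and tensor names are hypothetical, and the lm_head case again assumes a modules_to_not_convert match makes isQuantizedTensor return false:

```ts
const cfg = { quant_method: "gptq", bits: 4, modules_to_not_convert: ["lm_head"] };

getTensorSuffix("model.layers.3.mlp.gate_proj.qzeros"); // "qzeros"
getTensorSuffix("bias"); // "bias" (no dot, so the whole name is the suffix)

shouldSkipTensor("model.layers.3.mlp.gate_proj.qzeros", cfg); // true (auxiliary GPTQ state)
shouldSkipTensor("model.layers.3.mlp.gate_proj.scales", cfg); // true
shouldSkipTensor("model.layers.3.mlp.gate_proj.qweight", cfg); // false (packed weights are counted)
shouldSkipTensor("lm_head.scales", cfg); // false (module is excluded from quantization)
shouldSkipTensor("model.layers.3.mlp.gate_proj.scales", { quant_method: "mxfp4" }); // false (not gptq/awq)
```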