@@ -365,6 +365,9 @@ def prepare_int4_weight_and_scales_and_zeros(weight_bf16, groupsize, inner_k_tiles):
     weight_int4pack = torch.ops.aten._convert_weight_to_int4pack(weight_int32, inner_k_tiles)
     return weight_int4pack, scales_and_zeros
 
+def _calc_padded_size(k, groupsize=1, inner_k_tiles=1):
+    from model import find_multiple
+    return find_multiple(k, 1024)
 
 def linear_forward_int4(x, weight_int4pack, scales_and_zeros, out_features, groupsize):
     origin_x_size = x.size()
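
A minimal sketch of the rounding the new `_calc_padded_size` helper performs, assuming `model.find_multiple` rounds n up to the next multiple of k (the stand-in below is only for illustration):

    def find_multiple(n: int, k: int) -> int:
        # round n up to the nearest multiple of k (no-op if already aligned)
        return n if n % k == 0 else n + k - (n % k)

    assert find_multiple(4096, 1024) == 4096    # already a multiple of 1024, unchanged
    assert find_multiple(11008, 1024) == 11264  # padded up to the next multiple of 1024
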
@@ -378,29 +381,24 @@ def linear_forward_int4(x, weight_int4pack, scales_and_zeros, out_features, groupsize):
 def _check_linear_int4_k(k, groupsize = 1, inner_k_tiles = 1):
     return k % groupsize == 0 and k % (inner_k_tiles * 16) == 0
 
-def replace_linear_int4(module, groupsize, inner_k_tiles, padding, use_cuda):
+def replace_linear_int4(module, groupsize, inner_k_tiles, padding_allowed, use_cuda):
     for name, child in module.named_children():
         if isinstance(child, nn.Linear):
-            if _check_linear_int4_k(child.in_features, groupsize, inner_k_tiles):
+            if _check_linear_int4_k(child.in_features, groupsize, inner_k_tiles) or padding_allowed:
                 setattr(module, name, WeightOnlyInt4Linear(
                     child.in_features, child.out_features, bias=False,
-                    groupsize=groupsize, inner_k_tiles=inner_k_tiles, padding=False, use_cuda=use_cuda
-                ))
-            elif padding:
-                setattr(module, name, WeightOnlyInt4Linear(
-                    child.in_features, child.out_features, bias=False,
-                    groupsize=groupsize, inner_k_tiles=inner_k_tiles, padding=True, use_cuda=use_cuda
+                    groupsize=groupsize, inner_k_tiles=inner_k_tiles, use_cuda=use_cuda
                 ))
         else:
-            replace_linear_int4(child, groupsize, inner_k_tiles, padding, use_cuda)
+            replace_linear_int4(child, groupsize, inner_k_tiles, padding_allowed, use_cuda)
 
 
 class WeightOnlyInt4QuantHandler:
-    def __init__(self, mod, groupsize=128, inner_k_tiles=8, padding=True):
+    def __init__(self, mod, groupsize=128, inner_k_tiles=8, padding_allowed=True):
         self.mod = mod
         self.groupsize = groupsize
         self.inner_k_tiles = inner_k_tiles
-        self.padding = padding
+        self.padding_allowed = padding_allowed
         assert groupsize in [32, 64, 128, 256]
         assert inner_k_tiles in [2, 4, 8]
 
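
To illustrate the consolidated control flow (one branch gated by `padding_allowed` instead of two), here is a runnable sketch of the same recursive module-swap pattern, using a hypothetical stand-in layer so it does not depend on the int4 kernels:

    import torch.nn as nn

    class StubInt4Linear(nn.Module):
        # hypothetical stand-in for WeightOnlyInt4Linear; it just records the shape
        def __init__(self, in_features, out_features):
            super().__init__()
            self.in_features, self.out_features = in_features, out_features

    def replace_linear(module, k_is_ok, padding_allowed):
        for name, child in module.named_children():
            if isinstance(child, nn.Linear):
                if k_is_ok(child.in_features) or padding_allowed:
                    setattr(module, name, StubInt4Linear(child.in_features, child.out_features))
            else:
                replace_linear(child, k_is_ok, padding_allowed)

    m = nn.Sequential(nn.Linear(288, 768), nn.ReLU())
    replace_linear(m, k_is_ok=lambda k: k % 1024 == 0, padding_allowed=True)
    print(type(m[0]).__name__)  # StubInt4Linear: swapped even though 288 is not aligned
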
@@ -417,7 +415,7 @@ def create_quantized_state_dict(self):
 
                 weight = mod.weight.data
                 if not _check_linear_int4_k(in_features, self.groupsize, self.inner_k_tiles):
-                    if self.padding:
+                    if self.padding_allowed:
                         from model import find_multiple
                         import torch.nn.functional as F
                         print(f"warning: {fqn} is padded to satisfy in_features % 1024 == 0")
@@ -436,7 +434,7 @@ def create_quantized_state_dict(self):
         return cur_state_dict
 
     def convert_for_runtime(self, use_cuda):
-        replace_linear_int4(self.mod, self.groupsize, self.inner_k_tiles, self.padding, use_cuda)
+        replace_linear_int4(self.mod, self.groupsize, self.inner_k_tiles, self.padding_allowed, use_cuda)
         return self.mod
 
 class WeightOnlyInt4GPTQQuantHandler(GPTQQuantHandler):
@@ -460,7 +458,10 @@ def __init__(self, mod, groupsize=128, inner_k_tiles=8, padding=True):
         # we need to do the padding here, both for q and the qparams if necessary
         def make_names_and_values_dict_func(q, qparams):
             k = q.shape[1]
-            new_k = find_multiple(k, 1024)
+            if not _check_linear_int4_k(k, groupsize, inner_k_tiles):
+                new_k = find_multiple(k, 1024)
+            else:
+                new_k = k
             # how much we need to pad the weight
             delta_k = new_k - q.shape[1]
             final_q = torch.ops.aten._convert_weight_to_int4pack(F.pad(q, pad=(0, delta_k)), inner_k_tiles)
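
A small runnable sketch of the right-padding applied to q above: `F.pad(q, pad=(0, delta_k))` appends delta_k zero columns along the last (reduction) dimension before packing:

    import torch
    import torch.nn.functional as F

    q = torch.randint(0, 16, (8, 288), dtype=torch.int32)  # toy stand-in for a quantized weight, k = 288
    new_k = 1024                       # find_multiple(288, 1024)
    delta_k = new_k - q.shape[1]       # 736 zero columns appended on the right
    padded = F.pad(q, pad=(0, delta_k))
    print(padded.shape)                # torch.Size([8, 1024])
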
@@ -485,11 +486,11 @@ class WeightOnlyInt4Linear(torch.nn.Module):
 
     def __init__(
             self, in_features: int, out_features: int,
-            bias=True, device=None, dtype=None, groupsize: int = 128, inner_k_tiles: int = 8, padding: bool = True, use_cuda=True,
+            bias=True, device=None, dtype=None, groupsize: int = 128, inner_k_tiles: int = 8, use_cuda=True,
     ) -> None:
         super().__init__()
-        self.padding = padding
-        if padding:
+        self.padding = not _check_linear_int4_k(in_features, groupsize, inner_k_tiles)
+        if self.padding:
             from model import find_multiple
             self.origin_in_features = in_features
             in_features = find_multiple(in_features, 1024)
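
The constructor now derives `self.padding` from the same divisibility rule used elsewhere; a sketch of that rule with concrete widths:

    def check_linear_int4_k(k: int, groupsize: int = 1, inner_k_tiles: int = 1) -> bool:
        # mirrors _check_linear_int4_k: k must divide evenly into groups and inner tiles
        return k % groupsize == 0 and k % (inner_k_tiles * 16) == 0

    print(check_linear_int4_k(4096, groupsize=128, inner_k_tiles=8))  # True  -> no padding needed
    print(check_linear_int4_k(288, groupsize=128, inner_k_tiles=8))   # False -> padded to find_multiple(288, 1024) = 1024
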
@@ -502,16 +503,10 @@ def __init__(
 
         assert out_features % 8 == 0, "require out_features % 8 == 0"
         assert in_features % (inner_k_tiles * 16) == 0, "require in_features % (innerKTiles * 16) == 0"
-        if use_cuda:
-            self.register_buffer(
-                "weight",
-                torch.empty((out_features // 8, in_features // (inner_k_tiles * 16), 32, inner_k_tiles // 2), dtype=torch.int32)
-            )
-        else:
-            self.register_buffer(
-                "weight",
-                torch.empty((out_features, in_features // 2), dtype=torch.uint8)
-            )
+        self.register_buffer(
+            "weight",
+            torch.empty((out_features // 8, in_features // (inner_k_tiles * 16), 32, inner_k_tiles // 2), dtype=torch.int32)
+        )
         self.register_buffer(
             "scales_and_zeros",
             torch.empty((in_features // groupsize, out_features, 2), dtype=torch.bfloat16)
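
With the separate CPU uint8 layout removed, the weight buffer always uses the int4pack layout; a quick sketch of the resulting shape for a 4096x4096 layer with inner_k_tiles=8:

    out_features, in_features, inner_k_tiles = 4096, 4096, 8
    packed_shape = (out_features // 8, in_features // (inner_k_tiles * 16), 32, inner_k_tiles // 2)
    print(packed_shape)  # (512, 32, 32, 4), int32 elements, 8 int4 values packed per element
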
@@ -544,7 +539,7 @@ def quantize(
     device: str = default_device,
 ) -> None:
     assert checkpoint_path.is_file(), checkpoint_path
-
+    device = 'cpu'
     precision = torch.bfloat16
 
     print("Loading model ...")
@@ -554,6 +549,8 @@ def quantize(
         model = Transformer.from_name(checkpoint_path.parent.name)
 
     checkpoint = torch.load(str(checkpoint_path), mmap=True, weights_only=True)
+    if "model" in checkpoint and "stories" in str(checkpoint_path):
+        checkpoint = checkpoint["model"]
     model.load_state_dict(checkpoint, assign=True)
     model = model.to(dtype=precision, device=device)
 
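
A toy sketch of the new checkpoint unwrapping, assuming the "stories" checkpoints nest their state dict under a top-level "model" key (the dict and path below are stand-ins for the real file):

    checkpoint_path = "checkpoints/stories15M/stories15M.pth"   # hypothetical path
    checkpoint = {"model": {"layers.0.attention.wq.weight": 0}, "model_args": {}}  # toy stand-in
    if "model" in checkpoint and "stories" in str(checkpoint_path):
        checkpoint = checkpoint["model"]
    print(list(checkpoint.keys()))  # ['layers.0.attention.wq.weight']
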
@@ -597,7 +594,7 @@ def quantize(
 
         dir_name = checkpoint_path.parent
         base_name = checkpoint_path.name
-        new_base_name = base_name.replace('.pth', f"{label}int4-gptq.g{groupsize}.pth")
+        new_base_name = base_name.replace('.pth', f"{label}int4-gptq.g{groupsize}.{device}.pth")
     else:
         raise ValueError(f"Invalid quantization mode {mode} needs to be one of [int8, int4, int4-gptq]")
 