 except:
     pass
 
-from model import Transformer
+from model import Transformer, find_multiple
 
 ##### Quantization Primitives ######
 
@@ -365,29 +365,27 @@ def linear_forward_int4(x, weight_int4pack, scales_and_zeros, out_features, groupsize):
 def _check_linear_int4_k(k, groupsize=1, inner_k_tiles=1):
     return k % groupsize == 0 and k % (inner_k_tiles * 16) == 0
 
-def replace_linear_int4(module, groupsize, inner_k_tiles, padding):
+def _calc_padded_size_linear_int4(k, groupsize=1, inner_k_tiles=1):
+    return find_multiple(k, groupsize, inner_k_tiles * 16)
+
+def replace_linear_int4(module, groupsize, inner_k_tiles, padding_allowed):
     for name, child in module.named_children():
         if isinstance(child, nn.Linear):
-            if _check_linear_int4_k(child.in_features, groupsize, inner_k_tiles):
+            if _check_linear_int4_k(child.in_features, groupsize, inner_k_tiles) or padding_allowed:
                 setattr(module, name, WeightOnlyInt4Linear(
                     child.in_features, child.out_features, bias=False,
-                    groupsize=groupsize, inner_k_tiles=inner_k_tiles, padding=False,
-                ))
-            elif padding:
-                setattr(module, name, WeightOnlyInt4Linear(
-                    child.in_features, child.out_features, bias=False,
-                    groupsize=groupsize, inner_k_tiles=inner_k_tiles, padding=True,
+                    groupsize=groupsize, inner_k_tiles=inner_k_tiles,
                 ))
         else:
-            replace_linear_int4(child, groupsize, inner_k_tiles, padding)
+            replace_linear_int4(child, groupsize, inner_k_tiles, padding_allowed)
 
 
 class WeightOnlyInt4QuantHandler:
-    def __init__(self, mod, groupsize=128, inner_k_tiles=8, padding=True):
+    def __init__(self, mod, groupsize=128, inner_k_tiles=8, padding_allowed=True):
         self.mod = mod
         self.groupsize = groupsize
         self.inner_k_tiles = inner_k_tiles
-        self.padding = padding
+        self.padding_allowed = padding_allowed
         assert groupsize in [32, 64, 128, 256]
         assert inner_k_tiles in [2, 4, 8]
 
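Note: the new _calc_padded_size_linear_int4 helper leans on find_multiple, now imported from model at the top of the file, taking several divisors at once. A minimal sketch of that assumed behavior, rounding k up to the next multiple of the least common multiple of the divisors (an illustration of the expected contract, not the repository's exact implementation):

import math
from functools import reduce

def find_multiple(n: int, *ks: int) -> int:
    # Assumed semantics: smallest value >= n that is divisible by every k,
    # i.e. the next multiple of lcm(*ks).
    lcm = reduce(lambda a, b: a * b // math.gcd(a, b), ks, 1)
    if n % lcm == 0:
        return n
    return n + lcm - (n % lcm)

# e.g. _calc_padded_size_linear_int4(4000, 128, 8)
#   -> find_multiple(4000, 128, 128) -> 4096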
@@ -409,11 +407,9 @@ def create_quantized_state_dict(self, use_cuda = True):
 
                 weight = mod.weight.data
                 if not _check_linear_int4_k(in_features, self.groupsize, self.inner_k_tiles):
-                    if self.padding:
-                        from model import find_multiple
-                        import torch.nn.functional as F
+                    if self.padding_allowed:
                         print(f"warning: {fqn} is padded to satisfy in_features % 1024 == 0")
-                        padded_in_features = find_multiple(in_features, 1024)
+                        padded_in_features = _calc_padded_size_linear_int4(in_features, 1024)
                         weight = F.pad(weight, pad=(0, padded_in_features - in_features))
                     else:
                         print(f"warning: {fqn} is skipped, int4 requires that in_features is 32, 64, or is divisible by 1024, " +
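For reference, F.pad's flat pad tuple (0, d) pads only the last dimension, appending d zero columns on the right of the weight, which is exactly what in_features padding needs. A toy check:

import torch
import torch.nn.functional as F

w = torch.ones(2, 3)           # (out_features, in_features)
padded = F.pad(w, pad=(0, 2))  # 0 zeros on the left, 2 on the right of dim -1
print(padded.shape)            # torch.Size([2, 5])
print(padded[0])               # tensor([1., 1., 1., 0., 0.])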
@@ -428,31 +424,30 @@ def create_quantized_state_dict(self, use_cuda = True):
         return cur_state_dict
 
     def convert_for_runtime(self):
-        replace_linear_int4(self.mod, self.groupsize, self.inner_k_tiles, self.padding)
+        replace_linear_int4(self.mod, self.groupsize, self.inner_k_tiles, self.padding_allowed)
         return self.mod
 
 class WeightOnlyInt4GPTQQuantHandler(GPTQQuantHandler):
-    def __init__(self, mod, groupsize=128, inner_k_tiles=8, padding=True):
-        from model import find_multiple
+    def __init__(self, mod, groupsize=128, inner_k_tiles=8, padding_allowed=True):
         self.mod = mod
         self.groupsize = groupsize
         self.inner_k_tiles = inner_k_tiles
-        self.padding = padding
+        self.padding_allowed = padding_allowed
         self.get_qparams_func = lambda w: get_group_qparams(w, 4, groupsize)
         self.quantize_func = lambda w, qparams: \
             group_quantize_tensor_from_qparams(w, qparams[0], qparams[1], 4, groupsize)
         self.dequantize_func = lambda q, qparams: \
             group_dequantize_tensor_from_qparams(q, qparams[0], qparams[1], 4, groupsize).float()
         self.combine_qparams_list_func = lambda qparams_list: \
             [torch.cat(x, dim=1) for x in zip(*qparams_list)]
-        # skip unless padding=True or its correctly sized
+        # skip unless padding_allowed=True or it's correctly sized
         self.skip_layer_func = lambda linear_weight: not (
-            _check_linear_int4_k(linear_weight.shape[-1], groupsize, inner_k_tiles) or padding
+            _check_linear_int4_k(linear_weight.shape[-1], groupsize, inner_k_tiles) or padding_allowed
         )
         # we need to do the padding here, both for q and the qparams if necessary
         def make_names_and_values_dict_func(q, qparams):
             k = q.shape[1]
-            new_k = find_multiple(k, 1024)
+            new_k = _calc_padded_size_linear_int4(k, groupsize, inner_k_tiles)
             # how much we need to pad the weight
             delta_k = new_k - q.shape[1]
             final_q = torch.ops.aten._convert_weight_to_int4pack(F.pad(q, pad=(0, delta_k)), inner_k_tiles)
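A worked comparison of the old fixed-1024 padding against the new groupsize-aware size, under the find_multiple sketch above (hypothetical layer width; with groupsize=32 and inner_k_tiles=2 the constraint is divisibility by lcm(32, 32) = 32):

def padded_sizes(k: int, groupsize: int, inner_k_tiles: int) -> tuple[int, int]:
    # old scheme vs. new scheme, using the find_multiple sketch above
    return find_multiple(k, 1024), find_multiple(k, groupsize, inner_k_tiles * 16)

print(padded_sizes(4000, 32, 2))   # (4096, 4000): no padding needed anymore
print(padded_sizes(4000, 128, 8))  # (4096, 4096): same size in this case

When delta_k works out to 0, F.pad(q, pad=(0, 0)) above is a no-op.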
@@ -466,7 +461,7 @@ def make_names_and_values_dict_func(q, qparams):
 
 
     def convert_for_runtime(self):
-        replace_linear_int4(self.mod, self.groupsize, self.inner_k_tiles, self.padding)
+        replace_linear_int4(self.mod, self.groupsize, self.inner_k_tiles, self.padding_allowed)
         return self.mod
 
 class WeightOnlyInt4Linear(torch.nn.Module):
@@ -477,17 +472,16 @@ class WeightOnlyInt4Linear(torch.nn.Module):
 
     def __init__(
         self, in_features: int, out_features: int,
-        bias=True, device=None, dtype=None, groupsize: int = 128, inner_k_tiles: int = 8, padding: bool = True,
+        bias=True, device=None, dtype=None, groupsize: int = 128, inner_k_tiles: int = 8,
     ) -> None:
         super().__init__()
-        self.padding = padding
-        if padding:
-            from model import find_multiple
-            self.origin_in_features = in_features
-            in_features = find_multiple(in_features, 1024)
 
+        # always pad if needed since it becomes a noop at runtime if not needed
+        self.origin_in_features = in_features
+        in_features = _calc_padded_size_linear_int4(in_features, groupsize, inner_k_tiles)
         self.in_features = in_features
         self.out_features = out_features
+
         assert not bias, "require bias=False"
         self.groupsize = groupsize
         self.inner_k_tiles = inner_k_tiles
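With the padding flag gone from the constructor, a layer whose in_features misses the alignment is now padded transparently; a hypothetical instantiation (values assume the find_multiple sketch above):

layer = WeightOnlyInt4Linear(4000, 4096, bias=False, groupsize=128, inner_k_tiles=8)
print(layer.origin_in_features)  # 4000: the width callers see
print(layer.in_features)         # 4096: padded width from _calc_padded_size_linear_int4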
@@ -505,9 +499,7 @@ def __init__(
 
     def forward(self, input: torch.Tensor) -> torch.Tensor:
         input = input.to(torch.bfloat16)
-        if self.padding:
-            import torch.nn.functional as F
-            input = F.pad(input, pad=(0, self.in_features - self.origin_in_features))
+        input = F.pad(input, pad=(0, self.in_features - self.origin_in_features))
         return linear_forward_int4(
             input,
             self.weight, self.scales_and_zeros, self.out_features, self.groupsize
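Because self.in_features == self.origin_in_features whenever no padding was required, the now-unconditional F.pad degenerates to padding by zero elements, which is shape-preserving. A minimal check:

import torch
import torch.nn.functional as F

x = torch.randn(1, 4096, dtype=torch.bfloat16)
y = F.pad(x, pad=(0, 0))           # pad width 0: nothing appended
print(y.shape, torch.equal(x, y))  # torch.Size([1, 4096]) True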