@@ -326,8 +326,8 @@ def create_quantized_state_dict(self):
         for fqn, mod in self.mod.named_modules():
             if isinstance(mod, torch.nn.Linear):
                 int8_weight, scales, _ = dynamically_quantize_per_channel(mod.weight.float(), -128, 127, torch.int8)
-                cur_state_dict[f"{fqn}.weight"] = int8_weight
-                cur_state_dict[f"{fqn}.scales"] = scales.to(mod.weight.dtype)
+                cur_state_dict[f"{fqn}.weight"] = int8_weight.to('cpu')
+                cur_state_dict[f"{fqn}.scales"] = scales.to(mod.weight.dtype).to('cpu')

         return cur_state_dict

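The hunk above moves the freshly quantized int8 tensors to CPU before they enter the state dict, so the saved checkpoint is device-agnostic instead of being tied to the machine that produced it. A minimal sketch of the intended round trip, assuming this method belongs to the repo's int8 weight-only handler (the handler name and output path are illustrative, not taken from this diff):

    # Illustrative only: quantize, save a CPU-resident state dict, reload anywhere.
    handler = WeightOnlyInt8QuantHandler(model)            # class name assumed, not shown in this diff
    quantized_sd = handler.create_quantized_state_dict()   # every tensor now lives on CPU
    torch.save(quantized_sd, 'model_int8.pth')             # hypothetical output path
    restored = torch.load('model_int8.pth', map_location='cpu')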
@@ -376,21 +376,21 @@ def linear_forward_int4(x, weight_int4pack, scales_and_zeros, out_features, groupsize):
 def _check_linear_int4_k(k, groupsize=1, inner_k_tiles=1):
     return k % groupsize == 0 and k % (inner_k_tiles * 16) == 0

-def replace_linear_int4(module, groupsize, inner_k_tiles, padding):
+def replace_linear_int4(module, groupsize, inner_k_tiles, padding, use_cuda):
     for name, child in module.named_children():
         if isinstance(child, nn.Linear):
             if _check_linear_int4_k(child.in_features, groupsize, inner_k_tiles):
                 setattr(module, name, WeightOnlyInt4Linear(
                     child.in_features, child.out_features, bias=False,
-                    groupsize=groupsize, inner_k_tiles=inner_k_tiles, padding=False,
+                    groupsize=groupsize, inner_k_tiles=inner_k_tiles, padding=False, use_cuda=use_cuda
                 ))
             elif padding:
                 setattr(module, name, WeightOnlyInt4Linear(
                     child.in_features, child.out_features, bias=False,
-                    groupsize=groupsize, inner_k_tiles=inner_k_tiles, padding=True,
+                    groupsize=groupsize, inner_k_tiles=inner_k_tiles, padding=True, use_cuda=use_cuda
                 ))
         else:
-            replace_linear_int4(child, groupsize, inner_k_tiles, padding)
+            replace_linear_int4(child, groupsize, inner_k_tiles, padding, use_cuda)


 class WeightOnlyInt4QuantHandler:
@@ -403,12 +403,7 @@ def __init__(self, mod, groupsize=128, inner_k_tiles=8, padding=True):
         assert inner_k_tiles in [2, 4, 8]

     @torch.no_grad()
-    def create_quantized_state_dict(self, use_cuda = True):
-        if use_cuda:
-            device = "cuda"
-        else:
-            device = "cpu"
-
+    def create_quantized_state_dict(self):
         cur_state_dict = self.mod.state_dict()
         for fqn, mod in self.mod.named_modules():
             if isinstance(mod, torch.nn.Linear):
@@ -431,15 +426,15 @@ def create_quantized_state_dict(self, use_cuda = True):
                         "and that groupsize and inner_k_tiles*16 evenly divide into it")
                     continue
                 weight_int4pack, scales_and_zeros = prepare_int4_weight_and_scales_and_zeros(
-                    weight.to(torch.bfloat16).to(device=device), self.groupsize, self.inner_k_tiles
+                    weight.to(torch.bfloat16), self.groupsize, self.inner_k_tiles
                 )
                 cur_state_dict[f"{fqn}.weight"] = weight_int4pack.to('cpu')
                 cur_state_dict[f"{fqn}.scales_and_zeros"] = scales_and_zeros.to('cpu')

         return cur_state_dict

-    def convert_for_runtime(self):
-        replace_linear_int4(self.mod, self.groupsize, self.inner_k_tiles, self.padding)
+    def convert_for_runtime(self, use_cuda):
+        replace_linear_int4(self.mod, self.groupsize, self.inner_k_tiles, self.padding, use_cuda)
         return self.mod

 class WeightOnlyInt4GPTQQuantHandler(GPTQQuantHandler):
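With use_cuda threaded through convert_for_runtime into replace_linear_int4, the weight layout is now chosen when the quantized modules are swapped in, rather than being fixed at packing time inside create_quantized_state_dict. A hedged sketch of the resulting two-step flow (model construction omitted):

    handler = WeightOnlyInt4QuantHandler(model, groupsize=128)
    quantized_sd = handler.create_quantized_state_dict()  # packed weights, stored on CPU
    model = handler.convert_for_runtime(use_cuda=False)   # installs the CPU uint8 weight layout
    model.load_state_dict(quantized_sd)                   # layout must match the flag above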
@@ -476,8 +471,8 @@ def make_names_and_values_dict_func(q, qparams):
         super().__init__()


-    def convert_for_runtime(self):
-        replace_linear_int4(self.mod, self.groupsize, self.inner_k_tiles, self.padding)
+    def convert_for_runtime(self, use_cuda):
+        replace_linear_int4(self.mod, self.groupsize, self.inner_k_tiles, self.padding, use_cuda)
         return self.mod

 class WeightOnlyInt4Linear(torch.nn.Module):
@@ -488,7 +483,7 @@ class WeightOnlyInt4Linear(torch.nn.Module):

     def __init__(
         self, in_features: int, out_features: int,
-        bias=True, device=None, dtype=None, groupsize: int = 128, inner_k_tiles: int = 8, padding: bool = True,
+        bias=True, device=None, dtype=None, groupsize: int = 128, inner_k_tiles: int = 8, padding: bool = True, use_cuda=True,
     ) -> None:
         super().__init__()
         self.padding = padding
@@ -505,10 +500,16 @@ def __init__(

         assert out_features % 8 == 0, "require out_features % 8 == 0"
         assert in_features % (inner_k_tiles * 16) == 0, "require in_features % (innerKTiles * 16) == 0"
-        self.register_buffer(
-            "weight",
-            torch.empty((out_features // 8, in_features // (inner_k_tiles * 16), 32, inner_k_tiles // 2), dtype=torch.int32)
-        )
+        if use_cuda:
+            self.register_buffer(
+                "weight",
+                torch.empty((out_features // 8, in_features // (inner_k_tiles * 16), 32, inner_k_tiles // 2), dtype=torch.int32)
+            )
+        else:
+            self.register_buffer(
+                "weight",
+                torch.empty((out_features, in_features // 2), dtype=torch.uint8)
+            )
         self.register_buffer(
             "scales_and_zeros",
             torch.empty((in_features // groupsize, out_features, 2), dtype=torch.bfloat16)
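The two weight buffers above hold the same number of 4-bit values in different packings: the CUDA branch keeps the tiled int32 layout used by the CUDA int4 matmul kernel, while the CPU branch packs two 4-bit values into each uint8 byte. A quick sanity check with illustrative sizes:

    # Illustrative sizes; any shape passing the asserts in __init__ behaves the same.
    out_features, in_features, inner_k_tiles = 4096, 4096, 8
    cuda_int32s = (out_features // 8) * (in_features // (inner_k_tiles * 16)) * 32 * (inner_k_tiles // 2)
    assert cuda_int32s * 8 == out_features * in_features  # 8 int4 values per int32
    cpu_uint8s = out_features * (in_features // 2)
    assert cpu_uint8s * 2 == out_features * in_features   # 2 int4 values per uint8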
@@ -538,10 +539,10 @@ def quantize(
     percdamp: float = .01,
     blocksize: int = 128,
     label: str = '',
+    device: str = 'cuda',
 ) -> None:
     assert checkpoint_path.is_file(), checkpoint_path

-    device = 'cpu'
     precision = torch.bfloat16

     print("Loading model ...")
@@ -565,12 +566,13 @@ def quantize(

     elif mode == 'int4':
         print("Quantizing model weights for int4 weight-only affine per-channel groupwise quantization")
+        print(f"Prepacking model weights in {device} optimal layout")
         quant_handler = WeightOnlyInt4QuantHandler(model, groupsize)
         quantized_state_dict = quant_handler.create_quantized_state_dict()

         dir_name = checkpoint_path.parent
         base_name = checkpoint_path.name
-        new_base_name = base_name.replace('.pth', f"{label}int4.g{groupsize}.pth")
+        new_base_name = base_name.replace('.pth', f"{label}int4.g{groupsize}.{device}.pth")

     elif mode == 'int4-gptq':
         print("Quantizing model weights for int4 weight-only affine per-channel groupwise quantization using GPTQ...")
@@ -617,6 +619,7 @@ def quantize(
     parser.add_argument('--percdamp', type=float, default=.01, help='gptq percentage dampening')
     parser.add_argument('--blocksize', type=int, default=128, help='blocksize for gptq')
     parser.add_argument('--label', type=str, default='_', help='label to add to output filename')
+    parser.add_argument('--device', type=str, default='cuda', help='device to use')

     args = parser.parse_args()
-    quantize(args.checkpoint_path, args.mode, args.groupsize, args.calibration_tasks, args.calibration_limit, args.calibration_seq_length, args.pad_calibration_inputs, args.percdamp, args.blocksize, args.label)
+    quantize(args.checkpoint_path, args.mode, args.groupsize, args.calibration_tasks, args.calibration_limit, args.calibration_seq_length, args.pad_calibration_inputs, args.percdamp, args.blocksize, args.label, args.device)
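End to end, the new --device flag selects the packing layout and is stamped into the output filename. A usage sketch (the script name and checkpoint path are assumptions, not taken from this diff):

    # With the default label '_', a base name of model.pth becomes model_int4.g128.cpu.pth.
    python quantize.py --checkpoint_path checkpoints/model.pth --mode int4 --groupsize 128 --device cpu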