@@ -124,8 +124,8 @@ def group_quantize_tensor_from_qparams(w, scales, zeros, n_bit=4, groupsize=128)
         .to(torch.int32)
         .reshape_as(w)
     )
-
-    return w_int32
+    w_uint8 = (w_int32[::, ::2] << 4 | w_int32[::, 1::2]).to(torch.uint8)
+    return w_uint8


 def group_quantize_tensor(w, n_bit=4, groupsize=128):
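The new return value packs two 4-bit codes per byte: even columns land in the high nibble, odd columns in the low nibble. A minimal round-trip sketch of that scheme (the unpack_int4 helper is ours for illustration, not part of this patch):

import torch

def pack_int4(w_int32: torch.Tensor) -> torch.Tensor:
    # Same packing as the hunk above: even columns -> high nibble,
    # odd columns -> low nibble, two codes per uint8.
    return (w_int32[::, ::2] << 4 | w_int32[::, 1::2]).to(torch.uint8)

def unpack_int4(w_uint8: torch.Tensor) -> torch.Tensor:
    # Hypothetical inverse: split the nibbles and re-interleave the columns.
    high = (w_uint8 >> 4) & 0xF
    low = w_uint8 & 0xF
    return torch.stack([high, low], dim=-1).reshape(w_uint8.shape[0], -1).to(torch.int32)

w = torch.randint(0, 16, (4, 8), dtype=torch.int32)
assert torch.equal(unpack_int4(pack_int4(w)), w)

The round trip only holds for codes in [0, 15], which is what the clamp in group_quantize_tensor_from_qparams produces for n_bit=4.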
@@ -357,10 +357,9 @@ def forward(self, input: torch.Tensor) -> torch.Tensor:
 ##### weight only int4 per channel groupwise quantized code ######

 def prepare_int4_weight_and_scales_and_zeros(weight_bf16, groupsize, inner_k_tiles):
-    weight_int32, scales_and_zeros = group_quantize_tensor(
+    weight_int4pack, scales_and_zeros = group_quantize_tensor(
         weight_bf16, n_bit=4, groupsize=groupsize
     )
-    weight_int4pack = torch.ops.aten._convert_weight_to_int4pack(weight_int32, inner_k_tiles)
     return weight_int4pack, scales_and_zeros


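With the packing folded into group_quantize_tensor, the aten repack op is dropped and this helper becomes a pass-through; note inner_k_tiles is no longer used in the body, only kept in the signature. A shape sketch under the layout the other hunks establish:

# Assumed weight_bf16 shape: [out_features, in_features], dtype torch.bfloat16.
weight_int4pack, scales_and_zeros = prepare_int4_weight_and_scales_and_zeros(
    weight_bf16, groupsize=128, inner_k_tiles=8
)
# weight_int4pack: [out_features, in_features // 2], torch.uint8
# (two 4-bit codes per byte, matching the buffer registered below).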
@@ -404,7 +403,7 @@ def __init__(self, mod, groupsize=128, inner_k_tiles=8, padding=True):

     @torch.no_grad()
     def create_quantized_state_dict(self, use_cuda=True):
-        if use_cuda:
+        if use_cuda and torch.cuda.is_available():
             device = "cuda"
         else:
             device = "cpu"
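The guard turns use_cuda into a preference rather than a hard requirement, so quantization no longer fails on CPU-only machines. The selection logic, isolated in a hypothetical pick_device helper:

import torch

def pick_device(use_cuda: bool = True) -> str:
    # Fall back to CPU when CUDA is requested but unavailable.
    return "cuda" if use_cuda and torch.cuda.is_available() else "cpu"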
@@ -507,7 +506,7 @@ def __init__(
         assert in_features % (inner_k_tiles * 16) == 0, "require in_features % (innerKTiles * 16) == 0"
         self.register_buffer(
             "weight",
-            torch.empty((out_features // 8, in_features // (inner_k_tiles * 16), 32, inner_k_tiles // 2), dtype=torch.int32)
+            torch.empty((out_features, in_features // 2), dtype=torch.uint8)
         )
         self.register_buffer(
             "scales_and_zeros",
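The old 4D int4pack buffer and the new flat uint8 buffer occupy the same number of bytes; only the arrangement changes. A quick arithmetic check with illustrative dimensions:

out_features, in_features, inner_k_tiles = 4096, 4096, 8

# Old layout: 4D int4pack tensor of int32 (4 bytes per element).
old_bytes = 4 * (out_features // 8) * (in_features // (inner_k_tiles * 16)) * 32 * (inner_k_tiles // 2)

# New layout: one uint8 per pair of 4-bit weights.
new_bytes = out_features * (in_features // 2)

assert old_bytes == new_bytes  # 8 MiB either way for a 4096x4096 weight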