@@ -317,20 +317,22 @@ def fasterquant(self, blocksize=128, percdamp=0.01, group_size=-1):

             W[:, i2:] -= Err1.matmul(Hinv[i1:i2, i2:])

-        # print('error', torch.sum(Losses).item())
+        print("error", torch.sum(Losses).item())

         self.weight = Q.reshape(self.weight.shape).to(self.weight.dtype)
         self.scale = scale.to(self.weight.dtype)
         self.zero = zero.to(self.weight.dtype)


-def get_scale_zero(b, a, c, group_size, sign_ed):
+def get_scale_zero(b, a, c, group_size, bits, sign_ed):
     weight = b.clone()
     inp = a.clone()
     out = c.clone()
     gptq = GPTQ(weight)
     gptq.quantizer = Quantizer()
-    gptq.quantizer.configure(perchannel=True, sym=False, mse=False, signed=sign_ed)
+    gptq.quantizer.configure(
+        bits=bits, perchannel=True, sym=False, mse=False, sign_ed=sign_ed
+    )
     gptq.add_batch(inp, out)
     gptq.fasterquant(group_size=group_size)
@@ -341,8 +343,10 @@ def get_scale_zero(b, a, c, group_size, sign_ed):
     )


-def pack(weight, scale, zero):
-    intweight = torch.round((weight + zero) / scale).to(torch.int32)
+def pack(weight, scale, zero, minq, maxq):
+    intweight = torch.clamp(torch.round(weight / scale + zero), minq, maxq).to(
+        torch.int32
+    )
     qweight = torch.zeros(
         [weight.shape[0], weight.shape[1] // 8], dtype=torch.int32, device=weight.device
     )
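The hunk only shows `pack()` allocating `qweight` with shape `[rows, cols // 8]` in `int32`, i.e. eight 4-bit values per word; the packing loop itself falls outside the diff. A sketch of the nibble packing this layout implies, with illustrative names `pack_int4`/`unpack_int4` and an assumed nibble order (lowest nibble holds the first column of each group of eight):

```python
import torch

def pack_int4(intweight):
    # intweight: (rows, cols) int32 values already clamped to [0, 15]
    rows, cols = intweight.shape
    assert cols % 8 == 0
    qweight = torch.zeros(rows, cols // 8, dtype=torch.int32, device=intweight.device)
    for j in range(8):
        # nibble j of word k holds column 8*k + j
        qweight |= (intweight[:, j::8] & 0xF) << (4 * j)
    return qweight

def unpack_int4(qweight):
    rows, words = qweight.shape
    intweight = torch.zeros(rows, words * 8, dtype=torch.int32, device=qweight.device)
    for j in range(8):
        # mask after the shift so sign-extension from the top nibble is discarded
        intweight[:, j::8] = (qweight >> (4 * j)) & 0xF
    return intweight

w = torch.randint(0, 16, (2, 16), dtype=torch.int32)
assert torch.equal(unpack_int4(pack_int4(w)), w)  # lossless round trip
```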
@@ -377,7 +381,7 @@ def test(
     # Initialize tensors
     a = 1e0 * torch.randn([K, M], dtype=dtype).to(torch_device)
     layer = nn.Linear(K, N)
-    b = 1e-3 * layer.weight.data.to(dtype).to(torch_device)
+    b = 1e0 * layer.weight.data.to(dtype).to(torch_device)
     c = torch.zeros([N, M], dtype=dtype).to(torch_device)

     group_size = -1
@@ -393,13 +397,28 @@ def test(
     packed_weights = torch.zeros([N, K // 8], dtype=torch.int32).to(torch_device)
     s = torch.zeros([N, num_groups], dtype=dtype).to(torch_device)
     z = torch.zeros([N, num_groups], dtype=dtype).to(torch_device)
+    sign_ed = False
+    bits = 4
+    maxq = 2**bits - 1
+    minq = 0
+    if sign_ed:  # signed quantization: range [-8, 7]
+        maxq = 2 ** (bits - 1) - 1
+        minq = -(2 ** (bits - 1))
+    sym = False
+
     if torch_device == "cuda":
         b_ref, s, z = get_scale_zero(
-            b, a.t(), c, group_size, signed=False
+            b, a.t(), c, group_size, bits, sign_ed=sign_ed
         )  # unsigned quantization
-        z = torch.zeros_like(s)
-        packed_weights = pack(b_ref, s, z)
-        # print(s)
+
+        packed_weights = pack(b_ref, s, z, minq, maxq)
+
+    if torch_device == "cpu":
+        b_ref, s, z = get_scale_zero(
+            b, a.t(), c, group_size, bits, sign_ed=sign_ed
+        )  # unsigned quantization
+
+        packed_weights = pack(b_ref, s, z, minq, maxq)

     a_tensor, b_tensor, c_tensor, s_tensor, z_tensor, packed_weights_tensor = (
         to_tensor(a, lib),
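The new `sign_ed`/`bits` block in this hunk selects the integer grid that `pack()` later clamps to: for `bits = 4`, unsigned quantization covers `[0, 15]` and signed covers `[-8, 7]`. A small self-check of those formulas; `quant_range` is an illustrative helper, not part of the test:

```python
def quant_range(bits, sign_ed):
    # mirrors the added block in test(): unsigned [0, 2^bits - 1],
    # signed [-(2^(bits-1)), 2^(bits-1) - 1]
    if sign_ed:
        return -(2 ** (bits - 1)), 2 ** (bits - 1) - 1
    return 0, 2**bits - 1

assert quant_range(4, False) == (0, 15)
assert quant_range(4, True) == (-8, 7)
assert quant_range(8, True) == (-128, 127)
```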
@@ -444,19 +463,19 @@ def test(
     workspace = create_workspace(workspace_size.value, a.device)

     # Execute infiniop quantize_gptq operator
-    check_error(
-        lib.infiniopQuantizeGPTQ(
-            descriptor,
-            workspace.data_ptr() if workspace is not None else None,
-            workspace_size.value,
-            packed_weights_tensor.data,
-            s_tensor.data,
-            z_tensor.data,
-            a_tensor.data,
-            b_tensor.data,
-            None,
-        )
-    )
+    # check_error(
+    #     lib.infiniopQuantizeGPTQ(
+    #         descriptor,
+    #         workspace.data_ptr() if workspace is not None else None,
+    #         workspace_size.value,
+    #         packed_weights_tensor.data,
+    #         s_tensor.data,
+    #         z_tensor.data,
+    #         a_tensor.data,
+    #         b_tensor.data,
+    #         None,
+    #     )
+    # )

     def lib_quantize_gptq():
         check_error(
@@ -476,12 +495,12 @@ def lib_quantize_gptq():
     lib_quantize_gptq()

     atol, rtol = get_tolerance(_TOLERANCE_MAP, dtype)
-    # tmpa = ans.flatten()
-    # tmpc = c.flatten()
-    # for i in range(tmpa.shape[0]):
-    #     if abs(tmpa[i] - tmpc[i]) > atol + rtol * abs(tmpa[i]):
-    #         print(tmpa[i], tmpc[i], abs(tmpa[i] - tmpc[i]), rtol * abs(tmpa[i]))
-    #         break
+    tmpa = ans.flatten()
+    tmpc = c.flatten()
+    for i in range(tmpa.shape[0]):
+        if abs(tmpa[i] - tmpc[i]) > atol + rtol * abs(tmpa[i]):
+            print(tmpa[i], tmpc[i], abs(tmpa[i] - tmpc[i]), rtol * abs(tmpa[i]))
+            break

     if DEBUG:
         debug(c, ans, atol=atol, rtol=rtol)
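The re-enabled loop reports the first element where `|c - ans|` exceeds `atol + rtol * |ans|`, the same mixed absolute/relative criterion `torch.isclose` implements. An equivalent vectorized sketch; `first_mismatch` is an illustrative name:

```python
import torch

def first_mismatch(ans, c, atol, rtol):
    # fails where |c - ans| > atol + rtol * |ans|, matching the loop above
    ok = torch.isclose(c.flatten(), ans.flatten(), atol=atol, rtol=rtol)
    bad = (~ok).nonzero()
    if bad.numel() > 0:
        i = bad[0].item()
        return i, ans.flatten()[i].item(), c.flatten()[i].item()
    return None  # all elements within tolerance
```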