 # Configuration (Internal Use Only)
 # ==============================================================================
 # These are not meant to be imported from other modules
-is_weight_transposed = False

 _TEST_CASES = []

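+# Each case is (batch, layer[0], layer[1]); whether weights are transposed is
+# now decided per-device inside test() instead of being part of the case tuple.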
 for _, layers in MODELS.items():
     for layer in layers:
         for batch in [1, 16]:
-            _TEST_CASES.append((batch, layer[0], layer[1], is_weight_transposed))
+            _TEST_CASES.append((batch, layer[0], layer[1]))

 # Data types used for testing
 _TENSOR_DTYPES = [torch.float16]
@@ -324,14 +323,14 @@ def fasterquant(self, blocksize=128, percdamp=0.01, group_size=-1):
         self.zero = zero.to(self.weight.dtype)


-def get_scale_zero(b, a, c, group_size, bits, sign_ed):
+def get_scale_zero(b, a, c, group_size, bits, sym, sign_ed):
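+    # sym is now passed through to Quantizer.configure() below rather than
+    # being hard-coded to False.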
     weight = b.clone()
     inp = a.clone()
     out = c.clone()
     gptq = GPTQ(weight)
     gptq.quantizer = Quantizer()
     gptq.quantizer.configure(
-        bits=bits, perchannel=True, sym=False, mse=False, sign_ed=sign_ed
+        bits=bits, perchannel=True, sym=sym, mse=False, sign_ed=sign_ed
     )
     gptq.add_batch(inp, out)
     gptq.fasterquant(group_size=group_size)
@@ -370,7 +369,6 @@ def test(
     M,
     K,
     N,
-    is_weight_transposed,
     dtype=torch.float16,
     sync=None,
 ):
@@ -381,8 +379,13 @@ def test(
     # Initialize tensors
     a = 1e0 * torch.randn([K, M], dtype=dtype).to(torch_device)
     layer = nn.Linear(K, N)
-    b = 1e0 * layer.weight.data.to(dtype).to(torch_device)
+    b = 1e-3 * layer.weight.data.to(dtype).to(torch_device)
     c = torch.zeros([N, M], dtype=dtype).to(torch_device)
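+    # Quantization is unsigned (sign_ed=False) and asymmetric (sym=False);
+    # every device other than CPU consumes the tensors in transposed layout.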
+    is_weight_transposed = False
+    sign_ed = False
+    sym = False
+    if torch_device != "cpu":
+        is_weight_transposed = True

     group_size = -1
     num_groups = 1
@@ -397,37 +400,45 @@ def test(
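+    # Each int32 in packed_weights holds eight 4-bit values, hence K // 8 columns.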
     packed_weights = torch.zeros([N, K // 8], dtype=torch.int32).to(torch_device)
     s = torch.zeros([N, num_groups], dtype=dtype).to(torch_device)
     z = torch.zeros([N, num_groups], dtype=dtype).to(torch_device)
-    sign_ed = False
+
     bits = 4
     maxq = 2**bits - 1
     minq = 0
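+    # Default: unsigned 4-bit quantization over the range [0, 15].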
     if sign_ed:  # signed quantization: range is [-8, 7]
         maxq = 2 ** (bits - 1) - 1
         minq = -(2 ** (bits - 1))
-    sym = False

     if torch_device == "cuda":
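+        # CUDA path: run the GPTQ reference to obtain quantized weights,
+        # scales, and zero points, then pack them for the device kernel.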
         b_ref, s, z = get_scale_zero(
-            b, a.t(), c, group_size, bits, sign_ed=sign_ed
+            b, a.t(), c, group_size, bits, sym, sign_ed=sign_ed
         )  # unsigned quantization

         packed_weights = pack(b_ref, s, z, minq, maxq)

-    if torch_device == "cpu":
-        b_ref, s, z = get_scale_zero(
-            b, a.t(), c, group_size, bits, sign_ed=sign_ed
-        )  # unsigned quantization
-
-        packed_weights = pack(b_ref, s, z, minq, maxq)
+    # if torch_device == "cpu":
+    #     b_ref, s, z = get_scale_zero(
+    #         b, a.t(), c, group_size, bits, sym, sign_ed=sign_ed
+    #     )  # unsigned quantization

-    a_tensor, b_tensor, c_tensor, s_tensor, z_tensor, packed_weights_tensor = (
-        to_tensor(a, lib),
-        to_tensor(b, lib),
-        to_tensor(c, lib),
-        to_tensor(s, lib),
-        to_tensor(z, lib),
-        to_tensor(packed_weights, lib),
-    )
+    # packed_weights = pack(b_ref, s, z, minq, maxq)
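+    # Non-CPU devices receive every tensor transposed; CPU keeps the original layout.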
+    if is_weight_transposed:
+        a_tensor, b_tensor, c_tensor, s_tensor, z_tensor, packed_weights_tensor = (
+            to_tensor(a.t(), lib),
+            to_tensor(b.t(), lib),
+            to_tensor(c.t(), lib),
+            to_tensor(s.t(), lib),
+            to_tensor(z.t(), lib),
+            to_tensor(packed_weights.t(), lib),
+        )
+    else:
+        a_tensor, b_tensor, c_tensor, s_tensor, z_tensor, packed_weights_tensor = (
+            to_tensor(a, lib),
+            to_tensor(b, lib),
+            to_tensor(c, lib),
+            to_tensor(s, lib),
+            to_tensor(z, lib),
+            to_tensor(packed_weights, lib),
+        )

     descriptor = infiniopQuantizeGPTQDescriptor_t()
     check_error(
@@ -463,19 +474,19 @@ def test(
     workspace = create_workspace(workspace_size.value, a.device)

     # Execute infiniop quantize_gptq operator
-    # check_error(
-    #     lib.infiniopQuantizeGPTQ(
-    #         descriptor,
-    #         workspace.data_ptr() if workspace is not None else None,
-    #         workspace_size.value,
-    #         packed_weights_tensor.data,
-    #         s_tensor.data,
-    #         z_tensor.data,
-    #         a_tensor.data,
-    #         b_tensor.data,
-    #         None,
-    #     )
-    # )
+    check_error(
+        lib.infiniopQuantizeGPTQ(
+            descriptor,
+            workspace.data_ptr() if workspace is not None else None,
+            workspace_size.value,
+            packed_weights_tensor.data,
+            s_tensor.data,
+            z_tensor.data,
+            a_tensor.data,
+            b_tensor.data,
+            None,
+        )
+    )

     def lib_quantize_gptq():
         check_error(
@@ -495,13 +506,15 @@ def lib_quantize_gptq():
     lib_quantize_gptq()

     atol, rtol = get_tolerance(_TOLERANCE_MAP, dtype)
-    tmpa = ans.flatten()
-    tmpc = c.flatten()
-    for i in range(tmpa.shape[0]):
-        if abs(tmpa[i] - tmpc[i]) > atol + rtol * abs(tmpa[i]):
-            print(tmpa[i], tmpc[i], abs(tmpa[i] - tmpc[i]), rtol * abs(tmpa[i]))
-            break
+    # tmpa = ans.flatten()
+    # tmpc = c.flatten()
+    # for i in range(tmpa.shape[0]):
+    #     if abs(tmpa[i] - tmpc[i]) > atol + rtol * abs(tmpa[i]):
+    #         print(tmpa[i], tmpc[i], abs(tmpa[i] - tmpc[i]), rtol * abs(tmpa[i]))
+    #         break

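+    # Transpose the result back so c matches the layout of ans for comparison.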
+    if is_weight_transposed:
+        c = c.t()
     if DEBUG:
         debug(c, ans, atol=atol, rtol=rtol)
     assert torch.allclose(c, ans, atol=atol, rtol=rtol)