
 from torchao.dtypes.uintx.Uintx import to_uintx
 from torchao.quantization.quant_api import quantize_, uintx_weight_only
-from torchao.utils import TORCH_VERSION_AT_LEAST_2_5
+from torchao.utils import (
+    TORCH_VERSION_AT_LEAST_2_3,
+    TORCH_VERSION_AT_LEAST_2_5,
+)
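+# TORCH_VERSION_AT_LEAST_2_3 gates the sub-byte torch.uintN dtypes themselves;
+# TORCH_VERSION_AT_LEAST_2_5 gates the tensor-subclass fix that the skipif reasons below call out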

 from torchao.quantization.quant_primitives import (
     MappingType,
     ZeroPointDomain,
     choose_qparams_affine,
     quantize_affine,
     dequantize_affine,
 )

-bit_widths = (1, 2, 3, 4, 5, 6, 7)
+# torch.uintx dtypes were introduced in torch 2.3
+if TORCH_VERSION_AT_LEAST_2_3:
+    dtypes = (torch.uint1, torch.uint2, torch.uint3, torch.uint4, torch.uint5, torch.uint6, torch.uint7)
+else:
+    dtypes = ()
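+# each torch.uintN dtype carries the N-bit unsigned range [0, 2**N - 1]; the tests
+# below pass the dtype itself where the old API took an integer bit_width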
+
 group_sizes = [32, 64, 128]
 devices = ["cpu", "cuda"]
 @pytest.fixture(autouse=True)
@@ -36,72 +44,116 @@ def __init__(self, scale, device):
     def forward(self, x):
         return self.net(x)

-@pytest.mark.parametrize("bit_width", bit_widths)
+@pytest.mark.parametrize("dtype", dtypes)
 @pytest.mark.parametrize("group_size", group_sizes)
 @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
 @pytest.mark.skipif(not TORCH_VERSION_AT_LEAST_2_5, reason="only works with fix in the nightly build")
-def test_uintx_quant_on_cpu_then_move_to_cuda(bit_width, group_size):
+def test_uintx_quant_on_cpu_then_move_to_cuda(dtype, group_size):
     scale = 512
     fp16_mod_on_cpu = Linear16(scale, "cpu")
-    quantize_(fp16_mod_on_cpu, uintx_weight_only(bit_width, group_size=group_size))
+    quantize_(fp16_mod_on_cpu, uintx_weight_only(dtype, group_size=group_size))
     test_input_on_cpu = torch.randn(scale * 2, dtype=torch.float16, device="cpu")
     output_on_cpu = fp16_mod_on_cpu(test_input_on_cpu)
     fp16_mod_on_cuda = fp16_mod_on_cpu.to("cuda")
     test_input_on_cuda = test_input_on_cpu.to("cuda")
     output_on_cuda = fp16_mod_on_cuda(test_input_on_cuda)
     assert torch.allclose(output_on_cpu, output_on_cuda.cpu(), atol=1.0e-3), "The output of the model on CPU and CUDA should be close"

-@pytest.mark.parametrize("bit_width", bit_widths)
+@pytest.mark.parametrize("dtype", dtypes)
 @pytest.mark.parametrize("group_size", group_sizes)
 @pytest.mark.parametrize("device", devices)
 @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
 @pytest.mark.skipif(not TORCH_VERSION_AT_LEAST_2_5, reason="only works with fix in the nightly build")
-def test_uintx_weight_only_model_quant(bit_width, group_size, device):
+def test_uintx_weight_only_model_quant(dtype, group_size, device):
     scale = 512
     fp16 = Linear16(scale, device)
-    quantize_(fp16, uintx_weight_only(bit_width, group_size=group_size))
+    quantize_(fp16, uintx_weight_only(dtype, group_size=group_size))
     uintx = torch.compile(fp16, fullgraph=True)
     test_input = torch.randn(scale * 2, dtype=torch.float16, device=device)
     output = uintx.forward(test_input)
     assert output is not None, "model quantization failed"

-@pytest.mark.parametrize("bit_width", bit_widths)
+@pytest.mark.parametrize("dtype", dtypes)
 @pytest.mark.parametrize("group_size", group_sizes)
 @pytest.mark.parametrize("device", devices)
 @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
 @pytest.mark.skipif(not TORCH_VERSION_AT_LEAST_2_5, reason="only works with fix in the nightly build")
-def test_uintx_weight_only_quant(bit_width, group_size, device):
+def test_uintx_weight_only_quant(dtype, group_size, device):
     input_float = torch.randn((1, 256), dtype=torch.float16, device=device)
     mapping_type = MappingType.SYMMETRIC
-    quant_min = 0
-    quant_max = 2 ** bit_width - 1
     eps = torch.finfo(torch.float32).eps
     zero_point_dtype = torch.int32
     zero_point_domain = ZeroPointDomain.INT
-    target_dtype = torch.uint8
     block_size = (1, group_size)

     scale, zero_point = choose_qparams_affine(
         input_float, mapping_type, block_size,
-        target_dtype, quant_min, quant_max, eps, torch.float32,
-        zero_point_dtype, True, zero_point_domain
+        dtype, eps=eps, scale_dtype=torch.float32,
+        zero_point_dtype=zero_point_dtype, preserve_zero=True, zero_point_domain=zero_point_domain
     )

     aqt = quantize_affine(
         input_float, block_size, scale,
-        zero_point, target_dtype,
-        quant_min=quant_min,
-        quant_max=quant_max,
-        zero_point_domain=zero_point_domain
+        zero_point, dtype,
+        zero_point_domain=zero_point_domain
     )
+    # Note: the output is stored in a uint8 tensor for sub-byte dtypes for now
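+    # (e.g. for torch.uint3 the uint8 values lie in [0, 2**3 - 1]; to_uintx below
+    # repacks them into the packed sub-byte Uintx layout)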

-    q = to_uintx(aqt, bit_width, -1)
+    q = to_uintx(aqt, dtype, -1)
     assert q is not None, "quantization failed"
     dequant = dequantize_affine(
         q, block_size, scale,
-        zero_point, target_dtype,
-        quant_min=quant_min,
-        quant_max=quant_max,
-        zero_point_domain=zero_point_domain
+        zero_point, dtype,
+        zero_point_domain=zero_point_domain
     )
     assert dequant is not None, "dequantization failed"
+
+
+@pytest.mark.parametrize("dtype", dtypes)
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="Need CUDA available")
+@pytest.mark.skipif(not TORCH_VERSION_AT_LEAST_2_3, reason="sub-byte dtype requires torch 2.3+")
+def test_uintx_target_dtype(dtype):
+    from torchao.quantization.quant_api import uintx_weight_only
+    l = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device="cuda")
+    # make sure it runs
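+    # uintx_weight_only(dtype) returns the per-module transform that quantize_ would
+    # apply; calling it directly quantizes this Linear's weight in place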
+    uintx_weight_only(dtype)(l)
+    l(torch.randn(1, 128, dtype=torch.bfloat16, device="cuda"))
+
+@pytest.mark.parametrize("dtype", dtypes)
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="Need CUDA available")
+@pytest.mark.skipif(not TORCH_VERSION_AT_LEAST_2_5, reason="torch.compile without unwrap_tensor_subclass requires torch 2.5+")
+def test_uintx_target_dtype_compile(dtype):
+    from torchao.quantization.quant_api import uintx_weight_only
+    l = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device="cuda")
+    # make sure it runs
+    uintx_weight_only(dtype)(l)
+    l = torch.compile(l)
+    l(torch.randn(1, 128, dtype=torch.bfloat16, device="cuda"))
+
+
+@pytest.mark.parametrize("dtype", dtypes)
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="Need CUDA available")
+@pytest.mark.skipif(not TORCH_VERSION_AT_LEAST_2_3, reason="sub-byte dtype requires torch 2.3+")
+def test_uintx_model_size(dtype):
+    from torchao.quantization.quant_api import uintx_weight_only
+    from torchao.utils import get_model_size_in_bytes
+    # per element, assuming the default group_size of 64:
+    # scale size = 1/64 * 2 bytes = 1/32 byte
+    # zero_point size = 1/64 * 4 bytes = 1/16 byte
+    # quantized data size = bit_width / 8 bytes
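+    # e.g. for torch.uint4: 4/8 + 1/16 + 1/32 = 0.59375 byte per element,
+    # vs 2 bytes per bf16 element, so the expected ratio is 0.296875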
+    _dtype_to_ratio = {
+        torch.uint1: (1/8 + 1/16 + 1/32) / 2,
+        torch.uint2: (2/8 + 1/16 + 1/32) / 2,
+        torch.uint3: (3/8 + 1/16 + 1/32) / 2,
+        torch.uint4: (4/8 + 1/16 + 1/32) / 2,
+        torch.uint5: (5/8 + 1/16 + 1/32) / 2,
+        torch.uint6: (6/8 + 1/16 + 1/32) / 2,
+        torch.uint7: (7/8 + 1/16 + 1/32) / 2,
+    }
+    l = torch.nn.Sequential(
+        torch.nn.Linear(128, 256, bias=False, dtype=torch.bfloat16, device="cuda")
+    )
+    bf16_size = get_model_size_in_bytes(l)
+    # make sure it runs
+    uintx_weight_only(dtype)(l[0])
+    quantized_size = get_model_size_in_bytes(l)
+    assert bf16_size * _dtype_to_ratio[dtype] == quantized_size