     _replace_with_custom_fn_if_matches_filter,
     quantize_,
 )
-from torchao.quantization.subclass import (
-    Int4WeightOnlyQuantizedLinearWeight,
-    Int8WeightOnlyQuantizedLinearWeight,
-)
-
-
-def _int8wo_api(mod, **kwargs):
-    quantize_(mod, Int8WeightOnlyConfig(**kwargs), set_inductor_config=False)
-
-
-def _int8da_int8w_api(mod, **kwargs):
-    quantize_(
-        mod,
-        Int8DynamicActivationInt8WeightConfig(**kwargs),
-        set_inductor_config=False,
-    )
-
-
-def _int4wo_api(mod, **kwargs):
-    kwargs_copy = kwargs.copy()
-    if "groupsize" in kwargs_copy:
-        kwargs_copy["group_size"] = kwargs_copy["groupsize"]
-        del kwargs_copy["groupsize"]
-    quantize_(mod, Int4WeightOnlyConfig(**kwargs_copy), set_inductor_config=False)


 class ToyLinearModel(torch.nn.Module):
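
The helpers removed above were thin shims that forwarded kwargs into the matching config object, including a `groupsize` -> `group_size` rename for int4. A minimal sketch of the direct call they reduce to, assuming torchao's config-based `quantize_` API:

    from torchao.quantization import quantize_, Int4WeightOnlyConfig

    # Old: _int4wo_api(model, groupsize=32)
    # New: the group size is spelled once, at config construction time.
    quantize_(model, Int4WeightOnlyConfig(group_size=32), set_inductor_config=False)
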
@@ -68,34 +44,6 @@ def forward(self, x):
         return x


-def _ref_change_linear_weights_to_int8_dqtensors(model, filter_fn=None, **kwargs):
-    """
-    The deprecated implementation for int8 dynamic quant API, used as a reference for
-    numerics and performance
-    """
-    from torchao.quantization.quant_api import (
-        _get_subclass_inserter,
-        _is_linear,
-    )
-    from torchao.quantization.subclass import Int8DynamicallyQuantizedLinearWeight
-
-    def _in_features_greater_than_16(mod, *args):
-        return hasattr(mod, "in_features") and mod.in_features > 16
-
-    if filter_fn is None:
-        filter_fn = lambda *args: _is_linear(*args) and _in_features_greater_than_16(
-            *args
-        )
-
-    _replace_with_custom_fn_if_matches_filter(
-        model,
-        _get_subclass_inserter(
-            Int8DynamicallyQuantizedLinearWeight, enable_parametrization=False, **kwargs
-        ),
-        filter_fn,
-    )
-
-
 def _get_ref_change_linear_weights_to_woqtensors(deprecated_tenosr_subclass):
     def _ref_change_linear_weights_to_woqtensors(model, filter_fn=None, **kwargs):
         """
@@ -117,38 +65,18 @@ def _ref_change_linear_weights_to_woqtensors(model, filter_fn=None, **kwargs):
     return _ref_change_linear_weights_to_woqtensors


-_ref_change_linear_weights_to_int8_woqtensors = (
-    _get_ref_change_linear_weights_to_woqtensors(Int8WeightOnlyQuantizedLinearWeight)
-)
-_ref_change_linear_weights_to_int4_woqtensors = (
-    _get_ref_change_linear_weights_to_woqtensors(Int4WeightOnlyQuantizedLinearWeight)
-)
-
-
 torch._dynamo.config.cache_size_limit = 50000


 @torch.no_grad
-def _bench_quantized_tensor_subclass_perf(api, ref_api, M, N, K, kwargs=None):
-    if kwargs is None:
-        kwargs = {}
-
+def _bench_quantized_tensor_subclass_perf(api, config, M, N, K):
     m = ToyLinearModel(
         M, N, K, has_bias=True, dtype=torch.bfloat16, device="cuda"
     ).eval()
     m_bf16 = copy.deepcopy(m)
-    m_ref = copy.deepcopy(m)
     example_inputs = m.example_inputs()

-    api(m, **kwargs)
-
-    # reference
-    ref_api(m_ref, **kwargs)
-
-    res = m(*example_inputs)
-    ref = m_ref(*example_inputs)
-
-    assert torch.equal(res, ref)
+    api(m, config)  # Pass both model and config

     # perf comparison
     from torchao.utils import benchmark_model
@@ -158,22 +86,17 @@ def _bench_quantized_tensor_subclass_perf(api, ref_api, M, N, K, kwargs=None):
     RUNS = 100

     torch._dynamo.reset()
-    m_ref = torch.compile(m_ref, mode="max-autotune", fullgraph=True)
-    benchmark_model(m_ref, WARMUP, example_inputs)
-    ref_elapsed_time = benchmark_model(m_ref, RUNS, example_inputs)
+    m_bf16 = torch.compile(m_bf16, mode="max-autotune", fullgraph=True)
+    benchmark_model(m_bf16, WARMUP, example_inputs)
+    bf16_elapsed_time = benchmark_model(m_bf16, RUNS, example_inputs)

     torch._dynamo.reset()
     m = torch.compile(m, mode="max-autotune", fullgraph=True)
     benchmark_model(m, WARMUP, example_inputs)
     elapsed_time = benchmark_model(m, RUNS, example_inputs)

-    torch._dynamo.reset()
-    m_bf16 = torch.compile(m_bf16, mode="max-autotune", fullgraph=True)
-    benchmark_model(m_bf16, WARMUP, example_inputs)
-    bf16_elapsed_time = benchmark_model(m_bf16, RUNS, example_inputs)
-
     print(
-        f"{(M, N, K)}: elapsed time: {elapsed_time}, ref elapsed time: {ref_elapsed_time}, bf16 elapsed time: {bf16_elapsed_time}"
+        f"{(M, N, K)}: elapsed time: {elapsed_time}, bf16 elapsed time: {bf16_elapsed_time}"
     )


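
With the deprecated reference path gone, the exact `torch.equal` check is dropped and the comparison is quantized vs. bf16 timing only. If a numerics sanity check is still wanted, a loose-tolerance comparison against the bf16 copy is one option; a sketch, not part of this commit, with illustrative tolerances:

    # Quantized outputs won't match bf16 exactly, so use a loose tolerance
    # instead of the torch.equal check used against the old reference path.
    res = m(*example_inputs)
    ref = m_bf16(*example_inputs)
    torch.testing.assert_close(res, ref, atol=1e-1, rtol=1e-1)
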
@@ -182,24 +105,32 @@ def _bench_quantized_tensor_subclass_perf(api, ref_api, M, N, K, kwargs=None):
     (20, 2048, 2048),
 ]

-print("_int8da_int8w_api")
-
+print("Int8DynamicActivationInt8WeightConfig")
 for M, N, K in all_shapes:
     _bench_quantized_tensor_subclass_perf(
-        _int8da_int8w_api, _ref_change_linear_weights_to_int8_dqtensors, M, N, K
+        quantize_,
+        Int8DynamicActivationInt8WeightConfig(),
+        M,
+        N,
+        K,
     )

-print("_int8wo_api")
-
+print("Int8WeightOnlyConfig")
 for M, N, K in all_shapes:
     _bench_quantized_tensor_subclass_perf(
-        _int8wo_api, _ref_change_linear_weights_to_int8_woqtensors, M, N, K
+        quantize_,
+        Int8WeightOnlyConfig(),
+        M,
+        N,
+        K,
     )

-print("_int4wo_api")
-kwargs = {"groupsize": 32, "version": 1}
-
+print("Int4WeightOnlyConfig")
 for M, N, K in all_shapes:
     _bench_quantized_tensor_subclass_perf(
-        _int4wo_api, _ref_change_linear_weights_to_int4_woqtensors, M, N, K, kwargs
+        quantize_,
+        Int4WeightOnlyConfig(group_size=32),
+        M,
+        N,
+        K,
     )
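
Taken together, the driver loops now pass `quantize_` plus a config object instead of a wrapper/reference API pair. A self-contained sketch of the resulting flow, assuming a CUDA device and using a plain `nn.Linear` in place of `ToyLinearModel` for brevity:

    import copy
    import torch
    from torchao.quantization import quantize_, Int8WeightOnlyConfig
    from torchao.utils import benchmark_model

    model = torch.nn.Linear(2048, 2048, dtype=torch.bfloat16, device="cuda").eval()
    baseline = copy.deepcopy(model)
    example_inputs = (torch.randn(20, 2048, dtype=torch.bfloat16, device="cuda"),)

    quantize_(model, Int8WeightOnlyConfig())  # config object replaces per-scheme kwargs

    model = torch.compile(model, mode="max-autotune", fullgraph=True)
    baseline = torch.compile(baseline, mode="max-autotune", fullgraph=True)

    with torch.no_grad():
        benchmark_model(model, 10, example_inputs)     # warmup
        quant_time = benchmark_model(model, 100, example_inputs)
        benchmark_model(baseline, 10, example_inputs)  # warmup
        bf16_time = benchmark_model(baseline, 100, example_inputs)

    print(f"quantized: {quant_time}, bf16: {bf16_time}")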