     float8_weight_only,
     float8_dynamic_activation_float8_weight,
 )
+from torchao.quantization.quant_api import (
+    float8_static_activation_float8_weight,
+)
+from torchao.quantization.quant_primitives import choose_qparams_affine, MappingType
 from torchao.quantization.observer import PerTensor, PerRow
 from torchao.float8.float8_utils import compute_error
 import torch
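For orientation (not part of the diff): a minimal sketch of how the newly imported static-quantization API is invoked, assuming `quantize_` is importable from `torchao.quantization` as elsewhere in this test file. The `Sequential` model and the unit scale are placeholders; the tests below use `ToyLinearModel` and a scale derived from `choose_qparams_affine`.

```python
import torch
from torchao.quantization import quantize_
from torchao.quantization.observer import PerTensor
from torchao.quantization.quant_api import float8_static_activation_float8_weight

# Placeholder model; the tests use ToyLinearModel
model = torch.nn.Sequential(torch.nn.Linear(128, 256)).to("cuda")
# Precomputed per-tensor activation scale, kept in float32 as in the tests
act_scale = torch.tensor(1.0, dtype=torch.float32, device="cuda")
quantize_(
    model,
    float8_static_activation_float8_weight(scale=act_scale, granularity=PerTensor()),
)
```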
@@ -50,7 +54,7 @@ class TestAffineQuantizedFloat8Compile(InductorTestCase):
     @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
     @unittest.skipIf(not is_cuda_8_9, "Requires GPU with compute capability >= 8.9")
     @common_utils.parametrize("dtype", [torch.bfloat16, torch.float32])
-    @common_utils.parametrize("mode", ["dynamic", "weight-only"])
+    @common_utils.parametrize("mode", ["dynamic", "weight-only", "static"])
     @common_utils.parametrize("compile", [True, False])
     @common_utils.parametrize(
         "granularity", [PerTensor(), PerRow()] if is_H100 else [PerTensor()]
@@ -60,45 +64,57 @@ class TestAffineQuantizedFloat8Compile(InductorTestCase):
         "sizes",
         [
             ((128,), 256, 128),
-            ((256,), 512, 256),
-            ((64,), 128, 64),
             ((32, 128), 64, 256),
-            ((64, 256), 512, 128),
         ],
     )
     def test_fp8_linear_variants(
         self, dtype: torch.dtype, mode: str, compile: bool, sizes: Tuple, granularity
     ):
-        raises = (
-            isinstance(granularity, PerRow)
-            and mode == "dynamic"
-            and dtype != torch.bfloat16
-        )
-        context = (
-            nullcontext()
-            if not raises
-            else pytest.raises(
-                AssertionError,
-                match="PerRow quantization only works for bfloat16 precision",
-            )
+        error_message = None
+        if isinstance(granularity, PerRow):
+            if mode == "dynamic" and dtype != torch.bfloat16:
+                error_message = "PerRow quantization only works for bfloat16 precision"
+            elif mode == "static":
+                error_message = (
+                    "Static quantization only supports PerTensor granularity"
+                )
+
+        error_context = (
+            pytest.raises(AssertionError, match=error_message)
+            if error_message
+            else nullcontext()
         )
-        with context:
+
+        with error_context:
             M, N, K = sizes
             input_tensor = torch.randn(*M, K, dtype=dtype, device="cuda")
-
+            # Get a "reasonable" scale for the input tensor even though
+            # we use the same scale for multiple activations
+            scale, _ = choose_qparams_affine(
+                input_tensor,
+                MappingType.SYMMETRIC,
+                input_tensor.shape,
+                torch.float8_e4m3fn,
+                scale_dtype=torch.float32,
+            )
             mode_map = {
                 "dynamic": partial(
                     float8_dynamic_activation_float8_weight, granularity=granularity
                 ),
                 "weight-only": float8_weight_only,
+                "static": partial(
+                    float8_static_activation_float8_weight,
+                    scale=scale,
+                    granularity=granularity,
+                ),
             }

             # Create a linear layer with bfloat16 dtype
             model = ToyLinearModel(K, N).eval().to(dtype).to("cuda")

             quantized_model = copy.deepcopy(model)
             factory = mode_map[mode]()
-            quantize_(model, factory)
+            quantize_(quantized_model, factory)

             if compile:
                 quantized_model = torch.compile(quantized_model, fullgraph=True)
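A standalone sketch of the scale computation added above: `choose_qparams_affine` returns a `(scale, zero_point)` pair, and passing the tensor's full shape as the block size yields a single per-tensor scale; the zero point is unused under the symmetric mapping. The shapes here are illustrative, not from the tests.

```python
import torch
from torchao.quantization.quant_primitives import MappingType, choose_qparams_affine

x = torch.randn(32, 128, 256, dtype=torch.bfloat16, device="cuda")
scale, _zero_point = choose_qparams_affine(
    x,
    MappingType.SYMMETRIC,
    x.shape,  # block size == full shape -> one scale covering the whole tensor
    torch.float8_e4m3fn,
    scale_dtype=torch.float32,
)
```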
@@ -145,14 +161,23 @@ def test_per_row_with_float32(self):

     @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
     @unittest.skipIf(not is_cuda_8_9, "Requires GPU with compute capability >= 8.9")
-    @common_utils.parametrize("mode", ["dynamic", "weight-only"])
+    @common_utils.parametrize("mode", ["dynamic", "weight-only", "static"])
     def test_serialization(self, mode: str):
         # Create and quantize the model
         model = ToyLinearModel(16, 32).to(device="cuda")
-        if mode == "dynamic":
-            factory = float8_dynamic_activation_float8_weight()
-        else:
-            factory = float8_weight_only()
+
+        mode_map = {
+            "dynamic": partial(
+                float8_dynamic_activation_float8_weight, granularity=PerTensor()
+            ),
+            "weight-only": float8_weight_only,
+            "static": partial(
+                float8_static_activation_float8_weight,
+                scale=torch.tensor(1.0, dtype=torch.float32, device="cuda"),
+                granularity=PerTensor(),
+            ),
+        }
+        factory = mode_map[mode]()
         quantize_(model, factory)

         # Save the state dict to an in-memory buffer
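The save path above writes to an in-memory buffer rather than a file on disk. A self-contained sketch of that round trip, using a plain `nn.Linear` as a stand-in for the quantized `ToyLinearModel` (the hunk below additionally falls back to `weights_only=False` for the "dynamic" mode):

```python
import io

import torch

model = torch.nn.Linear(16, 32)  # stand-in for the quantized ToyLinearModel
buffer = io.BytesIO()            # in-memory substitute for a checkpoint file
torch.save(model.state_dict(), buffer)
buffer.seek(0)                   # rewind before reading the buffer back
state_dict = torch.load(buffer, weights_only=True)
```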
@@ -163,46 +188,50 @@ def test_serialization(self, mode: str):
         buffer.seek(0)

         # Load the state dict from the buffer
-        loaded_state_dict = torch.load(buffer)
+        weights_only_load = True
+        if mode == "dynamic":
+            # TODO will fix in followup
+            weights_only_load = False
+
+        loaded_state_dict = torch.load(buffer, weights_only=weights_only_load)

         # Create a new model and load the state dict
         with torch.device("meta"):
             new_model = ToyLinearModel(16, 32)
+            if mode == "static":
+                quantize_(new_model, factory)
             new_model.load_state_dict(loaded_state_dict, assign=True)

         # Compare the original and loaded models
-        if mode == "weight-only":
-            model_weight_1 = model.linear1.weight.layout_tensor.float8_data.to(
-                torch.float32
-            )
-            new_model_weight_1 = new_model.linear1.weight.layout_tensor.float8_data.to(
-                torch.float32
-            )
-
-            model_weight_2 = model.linear2.weight.layout_tensor.float8_data.to(
-                torch.float32
-            )
-            new_model_weight_2 = new_model.linear2.weight.layout_tensor.float8_data.to(
-                torch.float32
-            )
-
-        else:
-            model_weight_1 = model.linear1.weight.original_weight_tensor.layout_tensor.float8_data.to(
-                torch.float32
-            )
-            new_model_weight_1 = new_model.linear1.weight.original_weight_tensor.layout_tensor.float8_data.to(
-                torch.float32
-            )
-
-            model_weight_2 = model.linear2.weight.original_weight_tensor.layout_tensor.float8_data.to(
-                torch.float32
-            )
-            new_model_weight_2 = new_model.linear2.weight.original_weight_tensor.layout_tensor.float8_data.to(
-                torch.float32
-            )
-
-        assert torch.allclose(model_weight_1, new_model_weight_1)
-        assert torch.allclose(model_weight_2, new_model_weight_2)
+        for layer_name in ["linear1", "linear2"]:
+            original_layer = getattr(model, layer_name)
+            new_layer = getattr(new_model, layer_name)
+
+            # Compare weights
+            if mode == "weight-only":
+                original_weight = original_layer.weight.layout_tensor.float8_data.to(
+                    torch.float32
+                )
+                new_weight = new_layer.weight.layout_tensor.float8_data.to(
+                    torch.float32
+                )
+            else:
+                original_weight = original_layer.weight.original_weight_tensor.layout_tensor.float8_data.to(
+                    torch.float32
+                )
+                new_weight = new_layer.weight.original_weight_tensor.layout_tensor.float8_data.to(
+                    torch.float32
+                )
+
+            assert torch.allclose(
+                original_weight, new_weight
+            ), f"Weights do not match for {layer_name}"
+
+            # Compare scales
+            if hasattr(original_layer.weight, "scale"):
+                assert torch.allclose(
+                    original_layer.weight.scale, new_layer.weight.scale
+                ), f"Scales do not match for {layer_name}"


 common_utils.instantiate_parametrized_tests(TestAffineQuantizedFloat8Compile)
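One closing note on the load path in the diff: constructing the fresh model under `torch.device("meta")` allocates shape-only parameters with no real storage, and `load_state_dict(..., assign=True)` then adopts the deserialized tensors directly instead of copying into them; that is also why the "static" mode re-applies `quantize_` before loading, so the module structure matches the saved state dict. A minimal sketch of the pattern, using a plain `nn.Linear` stand-in:

```python
import io

import torch

saved = torch.nn.Linear(16, 32)
buffer = io.BytesIO()
torch.save(saved.state_dict(), buffer)
buffer.seek(0)
state_dict = torch.load(buffer, weights_only=True)

with torch.device("meta"):
    fresh = torch.nn.Linear(16, 32)  # parameters are shape-only meta tensors
# assign=True swaps the loaded tensors in rather than copying into meta storage
fresh.load_state_dict(state_dict, assign=True)
```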