from torchao.utils import torch_version_at_least


-class SelfAttnLikeModule(torch.nn.Module):
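+# Round-trip a tensor through fp8 (e4m3fn) quantize/dequantize at the given scale.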
+def qdq(input, scale):
+    dtype = input.dtype
+    q_input = torch.ops.torchao.quantize_affine_float8_non_decomposed.default(
+        input,
+        torch.tensor([scale]),
+        torch.float8_e4m3fn,
+    )
+    dq_input = torch.ops.torchao.dequantize_affine_float8_non_decomposed.default(
+        q_input,
+        torch.tensor([scale]),
+        dtype,
+    )
+    return dq_input
+
+
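+# Replace every Linear and SDPA submodule of `model` in place with its FP8 QDQ
+# stand-in (FP8QDQLinear / FP8QDQSDPA), rewiring the parent module accordingly.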
+def fp8_convert_(model):
+    def generate_model_info(model):
+        from collections import namedtuple
+
+        mod_inst_info = namedtuple("ModInstInfo", ["name", "parent"])
+        parent_child_mod_dict = {}
+
+        def create_mod_info_recursion(parent):
+            for name, mod in parent.named_children():
+                parent_child_mod_dict[mod] = mod_inst_info(name=name, parent=parent)
+                create_mod_info_recursion(mod)
+
+        create_mod_info_recursion(model)
+        return parent_child_mod_dict
+
+    parent_child_mod_dict = generate_model_info(model)
+    for name, mod in model.named_modules():
+        mod_type_str = mod.__class__.__name__
+        if mod_type_str not in [
+            "Linear",
+            "SDPA",
+        ]:
+            continue
+        if mod_type_str == "Linear":
+            param = mod.weight
+            xmax = torch.max(param)
+            weight_scale = xmax / torch.finfo(torch.float8_e4m3fn).max
+            mod.weight_scale = weight_scale
+            q_param = torch.clamp(
+                (param / weight_scale),
+                torch.finfo(torch.float8_e4m3fn).min,
+                torch.finfo(torch.float8_e4m3fn).max,
+            ).to(torch.float8_e4m3fn)
+            mod.weight.data = q_param
+            patched_mod = FP8QDQLinear(mod.in_features, mod.out_features, False)
+            patched_mod.bias = mod.bias
+            patched_mod.weight_scale = weight_scale.item()
+            patched_mod.weight.data = q_param
+        else:
+            patched_mod = FP8QDQSDPA()
+            patched_mod.__dict__.update(mod.__dict__)
+            patched_mod.transpose_for_scores = mod.transpose_for_scores
+
+            patched_mod.q_out_scale = (
+                patched_mod.q_out_scale / torch.finfo(torch.float8_e4m3fn).max
+            )
+            patched_mod.k_out_scale = (
+                patched_mod.k_out_scale / torch.finfo(torch.float8_e4m3fn).max
+            )
+            patched_mod.attn_weights_scale = (
+                patched_mod.attn_weights_scale / torch.finfo(torch.float8_e4m3fn).max
+            )
+            patched_mod.v_out_scale = (
+                patched_mod.v_out_scale / torch.finfo(torch.float8_e4m3fn).max
+            )
+            patched_mod.qk_out_scale = (
+                patched_mod.qk_out_scale / torch.finfo(torch.float8_e4m3fn).max
+            )
+            patched_mod.attn_out_scale = (
+                patched_mod.attn_out_scale / torch.finfo(torch.float8_e4m3fn).max
+            )
+
+        parent = parent_child_mod_dict[mod].parent
+        name = parent_child_mod_dict[mod].name
+        setattr(parent, name, patched_mod)
+    model.eval()
+    return model
+
+
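+# Linear stand-in that stores an fp8 weight; forward dequantizes the weight and
+# QDQs the activation before calling torch.nn.functional.linear.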
+class FP8QDQLinear(torch.nn.Module):
+    def __init__(self, in_features, out_features, has_bias):
+        super().__init__()
+        self.qtype = torch.float8_e4m3fn
+        self.weight = torch.randn((out_features, in_features)).to(self.qtype)
+        self.weight_scale = 2.0
+        self.scale = 2.0
+        self.bias = None
+        if has_bias:
+            self.bias = torch.randn((out_features,))
+
+    def forward(self, input):
+        weight = torch.ops.torchao.dequantize_affine_float8_non_decomposed.default(
+            tensor=self.weight.data,
+            scale=torch.tensor([self.weight_scale]),
+            output_dtype=torch.float,
+        )
+
+        q_input = torch.ops.torchao.quantize_affine_float8_non_decomposed.default(
+            tensor=input,
+            scale=torch.tensor([self.scale]),
+            float8_dtype=self.qtype,
+        )
+        dq_input = torch.ops.torchao.dequantize_affine_float8_non_decomposed.default(
+            tensor=q_input,
+            scale=torch.tensor([self.scale]),
+            output_dtype=torch.float,
+        )
+
+        out = torch.nn.functional.linear(dq_input, weight, self.bias)
+        return out
+
+
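+# SDPA stand-in that wraps the q/k and probs/v matmuls in fp8 QDQ ops.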
+class FP8QDQSDPA(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.q_out_scale = 1.5
+        self.k_out_scale = 1.5
+        self.attn_weights_scale = 1.5
+        self.v_out_scale = 1.5
+        self.attn_out_scale = 1.5
+        self.qk_out_scale = 1.5
+
+    def forward(self, q, k, v, mask):
+        query = self.transpose_for_scores(q)
+        key = self.transpose_for_scores(k)
+        value = self.transpose_for_scores(v)
+
+        # Take the dot product between "query" and "key" to get the raw attention scores.
+        query_qdq = qdq(query, self.q_out_scale)
+        key_qdq = qdq(key.transpose(-1, -2), self.k_out_scale)
+        attn_weights = torch.matmul(query_qdq, key_qdq) / (self.input_dim**0.5)
+
+        # Normalize the attention scores to probabilities.
+        attn_weights = torch.nn.functional.softmax(
+            attn_weights, dim=-1, dtype=torch.float32
+        ).to(query.dtype)
+
+        # This is actually dropping out entire tokens to attend to, which might
+        # seem a bit unusual, but is taken from the original Transformer paper.
+        dropout = 0.0 if not self.training else self.dropout_prob
+        attn_weights = torch.nn.functional.dropout(
+            attn_weights, p=dropout, training=self.training
+        )
+
+        # Mask heads if we want to
+        if mask is not None:
+            attn_weights = attn_weights + mask
+
+        value_qdq = qdq(value, self.v_out_scale)
+        attn_weights_qdq = qdq(attn_weights, self.attn_weights_scale)
+        attn_output = torch.matmul(attn_weights_qdq, value_qdq)
+        attn_output = attn_output.transpose(1, 2).contiguous()
+
+        new_context_layer_shape = attn_output.size()[:-2] + (self.all_head_size,)
+        attn_output = attn_output.reshape(new_context_layer_shape)
+
+        return attn_output
+
+
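+# Attention core operating on pre-projected q/k/v; the projections and the output
+# dense layer now live in MHAModule below.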
+class SDPA(torch.nn.Module):
    def __init__(
        self,
        input_dim,
        has_mask,
-        num_attention_heads=None,
-        attention_head_size=None,
+        num_attention_heads,
+        attention_head_size,
    ) -> None:
        super().__init__()
        self.input_dim = input_dim
-        self.q_proj = torch.nn.Linear(input_dim, input_dim, bias=False)
-        self.k_proj = torch.nn.Linear(input_dim, input_dim, bias=False)
-        self.v_proj = torch.nn.Linear(input_dim, input_dim, bias=False)
        self.softmax = torch.nn.Softmax(dim=-1)
-        assert num_attention_heads is not None
-        assert attention_head_size is not None
        self.num_attention_heads = num_attention_heads
        self.attention_head_size = attention_head_size
        self.all_head_size = self.num_attention_heads * self.attention_head_size
-        self.dense = torch.nn.Linear(self.all_head_size, self.all_head_size)
        self.dropout = torch.nn.Dropout(0)
        self.has_mask = has_mask

@@ -49,10 +207,7 @@ def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
        x = x.view(new_x_shape)
        return x.permute([0, 2, 1, 3])

-    def forward(self, x, mask):
-        q = self.q_proj(x)
-        k = self.k_proj(x)
-        v = self.v_proj(x)
+    def forward(self, q, k, v, mask):
        q = self.transpose_for_scores(q)
        k = self.transpose_for_scores(k)
        v = self.transpose_for_scores(v)
@@ -63,9 +218,38 @@ def forward(self, x, mask):
        attention = self.dropout(attention)
        context_layer = torch.matmul(attention, v)
        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
-        context_layer = context_layer.view(
-            context_layer.size()[:-2] + (self.all_head_size,)
+        return context_layer.reshape(context_layer.size()[:-2] + (self.all_head_size,))
+
+
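+# Full attention block: q/k/v projections feeding the SDPA core, then the output dense.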
+class MHAModule(torch.nn.Module):
+    def __init__(
+        self,
+        input_dim,
+        has_mask,
+        num_attention_heads,
+        attention_head_size,
+    ) -> None:
+        super().__init__()
+        self.input_dim = input_dim
+        self.q_proj = torch.nn.Linear(input_dim, input_dim, bias=False)
+        self.k_proj = torch.nn.Linear(input_dim, input_dim, bias=False)
+        self.v_proj = torch.nn.Linear(input_dim, input_dim, bias=False)
+        self.num_attention_heads = num_attention_heads
+        self.attention_head_size = attention_head_size
+        self.all_head_size = self.num_attention_heads * self.attention_head_size
+        self.dense = torch.nn.Linear(self.all_head_size, self.all_head_size)
+        self.attn_mod = SDPA(
+            input_dim,
+            has_mask,
+            num_attention_heads,
+            attention_head_size,
        )
+
+    def forward(self, x, mask):
+        q = self.q_proj(x)
+        k = self.k_proj(x)
+        v = self.v_proj(x)
+        context_layer = self.attn_mod(q, k, v, mask)
        return self.dense(context_layer)


@@ -158,7 +342,7 @@ def _check_common(
        reason="cpp kernels not built",
    )
    @config.patch({"freezing": True})
-    def _test_qsdpa_rewriter(self):
+    def _test_int8_sdpa_rewriter(self):
        import torchao.quantization.pt2e.quantizer.x86_inductor_quantizer as xiq
        from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e
        from torchao.quantization.pt2e.quantizer.x86_inductor_quantizer import (
@@ -171,7 +355,7 @@ def _test_qsdpa_rewriter(self):
            [torch.float32, torch.bfloat16], [True, False], [56, 1]
        ):
            seqlen, numhead, headsize = 197, 16, 64
-            mod = SelfAttnLikeModule(
+            mod = MHAModule(
                input_dim=headsize * numhead,
                has_mask=has_mask,
                num_attention_heads=numhead,
@@ -204,6 +388,51 @@ def _test_qsdpa_rewriter(self):
                prepare_model(*inputs)
                convert_model = convert_pt2e(prepare_model)
                torchao.quantization.pt2e.move_exported_model_to_eval(convert_model)
+
+                self._check_common(
+                    convert_model, args1=inputs, check_train=False, atol=1.0
+                )
+
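+    # FP8 variant: convert MHAModule with fp8_convert_ and check that the compiled
+    # model (with the custom qsdpa pass) matches eager execution.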
+    @skipIfRocm
+    @unittest.skipIf(
+        not torch_version_at_least("2.7.0"),
+        reason="qsdpa requires torch 2.7 or later",
+    )
+    @unittest.skipIf(
+        "CPU" not in torch._C._dispatch_dump("torchao::qscaled_dot_product"),
+        reason="cpp kernels not built",
+    )
+    @config.patch({"freezing": True})
+    def _test_fp8_sdpa_rewriter(self):
+        import torchao.quantization.pt2e.quantizer.x86_inductor_quantizer as xiq  # noqa: F401
+
+        # pattern is different for bs=1
+        torch.manual_seed(1234)
+        for dtype, bs in itertools.product([torch.float32, torch.bfloat16], [56, 1]):
+            seqlen, numhead, headsize = 197, 16, 64
+            mod = MHAModule(
+                input_dim=headsize * numhead,
+                has_mask=False,
+                num_attention_heads=numhead,
+                attention_head_size=headsize,
+            ).eval()
+            inputs = (
+                torch.randn(
+                    (bs, seqlen, headsize * numhead), device=self.device, dtype=dtype
+                ),
+                None,
+            )
+            enable_autocast = dtype == torch.bfloat16
+            with (
+                torch.no_grad(),
+                torch.amp.autocast(
+                    self.device, enabled=enable_autocast, dtype=torch.bfloat16
+                ),
+                config.patch(post_grad_custom_pre_pass=custom_pass),
+            ):
+                _qsdpa_init()
+                convert_model = fp8_convert_(mod)
+
                self._check_common(
                    convert_model, args1=inputs, check_train=False, atol=1.0
                )
@@ -213,7 +442,12 @@ def _test_qsdpa_rewriter(self):

class SDPAPatternRewriterCpuTests(TestSDPAPatternRewriterTemplate):
    device = "cpu"
-    test_qsdpa_rewriter_cpu = TestSDPAPatternRewriterTemplate._test_qsdpa_rewriter
+    test_int8_sdpa_rewriter_cpu = (
+        TestSDPAPatternRewriterTemplate._test_int8_sdpa_rewriter
+    )
+    test_fp8_sdpa_rewriter_cpu = (
+        TestSDPAPatternRewriterTemplate._test_fp8_sdpa_rewriter
+    )


if __name__ == "__main__":