@@ -292,20 +292,41 @@ def _masked_custom_scale_post_mul_sdpa_script(query, key, value, mask):
292292 return attn_output
293293
294294
295+ # This tests a scenario where the key is in BSHd format instead of BHSd, which
296+ # happens due to an optimization that fuses two transposes together, the one
297+ # to convert from BSHd to BHSd and then to BHdS before MatMul. Hence, the first
298+ # transpose down below is different from other test cases.
299+ @script ()
300+ def _unmasked_pre_div_sdpa_BSHd_key_script (query , key , value ):
301+ key_transposed = op .Transpose (key , perm = [0 , 2 , 3 , 1 ]) # BSHd to BHdS
302+ divisor = op .Constant (value_float = SQRT_SCALE_FACTOR )
303+ scaled_query = op .Div (query , divisor )
304+ scaled_key = op .Div (key_transposed , divisor )
305+ attn_score = op .MatMul (scaled_query , scaled_key )
306+ attn_weight = op .Softmax (attn_score , axis = - 1 )
307+ is_nan = op .IsNaN (attn_weight )
308+ zero = op .Constant (value_float = 0.0 )
309+ adj_attn_weight = op .Where (is_nan , zero , attn_weight )
310+ attn_output = op .MatMul (adj_attn_weight , value )
311+ return attn_output
312+
313+
295314class SDPATestCase :
296- def __init__ (self , script_func , * , with_mask ):
315+ def __init__ (self , script_func , * , with_mask , BSHd_key = False ):
297316 self .script_func = script_func
298317 self .with_mask = with_mask
318+ self .BSHd_key = BSHd_key
299319
300320 def get_onnx_model (self ):
301321 if not hasattr (self , "_onnx_model" ):
302- qkv_type = FLOAT [B , N , S , H ]
322+ qv_type = FLOAT [B , N , S , H ]
303323 mask_type = FLOAT [B , N , S , S ]
304- input_types = [qkv_type , qkv_type , qkv_type ]
324+ k_type = FLOAT [B , S , N , H ] if self .BSHd_key else FLOAT [B , N , S , H ]
325+ input_types = [qv_type , k_type , qv_type ]
305326 if self .with_mask :
306327 input_types .append (mask_type )
307328 model_proto = self .script_func .to_model_proto (
308- input_types = input_types , output_types = [qkv_type ]
329+ input_types = input_types , output_types = [qv_type ]
309330 )
310331 self ._onnx_model = ir .serde .deserialize_model (model_proto )
311332 return self ._onnx_model
@@ -314,7 +335,9 @@ def get_ort_inputs(self):
314335 if not hasattr (self , "_ort_inputs" ):
315336 inputs = {
316337 "query" : numpy .random .rand (B , N , S , H ).astype (numpy .float32 ),
317- "key" : numpy .random .rand (B , N , S , H ).astype (numpy .float32 ),
338+ "key" : numpy .random .rand (B , S , N , H ).astype (numpy .float32 )
339+ if self .BSHd_key
340+ else numpy .random .rand (B , N , S , H ).astype (numpy .float32 ),
318341 "value" : numpy .random .rand (B , N , S , H ).astype (numpy .float32 ),
319342 }
320343 if self .with_mask :
@@ -374,10 +397,13 @@ class TestSDPAFusion(unittest.TestCase):
374397 "_custom_multi_scale_pre_mul_sdpa_script" ,
375398 _custom_multi_scale_pre_mul_sdpa_script ,
376399 ),
400+ ("pre_div_sdpa_BSHd_key" , _unmasked_pre_div_sdpa_BSHd_key_script ),
377401 ]
378402 )
379403 def test_sdpa_fusion (self , name , script_func ):
380- test_case = SDPATestCase (script_func , with_mask = "masked" in name )
404+ test_case = SDPATestCase (
405+ script_func , with_mask = "masked" in name , BSHd_key = "BSHd_key" in name
406+ )
381407 model = test_case .get_onnx_model ()
382408 onnxscript .optimizer .optimize (model )
383409
0 commit comments