@@ -463,11 +463,13 @@ def scaled_dot_product_attention_decomposition(
 ) -> torch.Tensor:
     L, S = query.size(-2), key.size(-2)
     device = query.device
-    attn_bias = torch.zeros(L, S, dtype=query.dtype, device=device)
+
+    if is_causal or attn_mask is not None:
+        attn_bias = torch.zeros((L, S), dtype=query.dtype, device=device)
 
     if is_causal:
         assert attn_mask is None, "attn_mask must be None when is_causal=True"
-        temp_mask = torch.ones(L, S, dtype=torch.bool, device=device).tril(diagonal=0)
+        temp_mask = torch.ones((L, S), dtype=torch.bool, device=device).tril(diagonal=0)
         attn_bias = attn_bias.masked_fill(temp_mask.logical_not(), float("-inf"))
 
     if attn_mask is not None:
@@ -480,7 +482,7 @@ def scaled_dot_product_attention_decomposition(
         key = key.repeat_interleave(query.size(-3) // key.size(-3), -3)
         value = value.repeat_interleave(query.size(-3) // value.size(-3), -3)
 
-    attn_weight = query @ key.transpose(-2, -1)
+    attn_weight = torch.matmul(query, key.transpose(-2, -1))
 
     if scale is None:
         scale = torch.sqrt(torch.scalar_tensor(query.size(-1), dtype=torch.int)).to(
@@ -490,9 +492,12 @@ def scaled_dot_product_attention_decomposition(
     else:
         attn_weight = attn_weight * scale
 
-    attn_weight = attn_weight + attn_bias
+    if is_causal or attn_mask is not None:
+        # We only add attn_bias when we have to; otherwise it hurts performance even when it is all zeros.
+        attn_weight = attn_weight + attn_bias
+
     attn_weight = torch.softmax(attn_weight, dim=-1)
-    return attn_weight @ value
+    return torch.matmul(attn_weight, value)
 
 
 @register_torch_trt_decomposition(
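
For reference, here is a minimal standalone sketch of the decomposed path after this change, checked against torch.nn.functional.scaled_dot_product_attention. The sdpa_decomposed name is illustrative; the attn_mask branch and the simplified scaling step are assumptions reconstructed from the surrounding code (they sit outside the hunks above), and the GQA repeat_interleave path is omitted for brevity.

import torch


def sdpa_decomposed(query, key, value, attn_mask=None, is_causal=False, scale=None):
    # attn_bias is only materialized (and later added) when a mask is actually needed.
    L, S = query.size(-2), key.size(-2)
    device = query.device

    if is_causal or attn_mask is not None:
        attn_bias = torch.zeros((L, S), dtype=query.dtype, device=device)

    if is_causal:
        assert attn_mask is None, "attn_mask must be None when is_causal=True"
        temp_mask = torch.ones((L, S), dtype=torch.bool, device=device).tril(diagonal=0)
        attn_bias = attn_bias.masked_fill(temp_mask.logical_not(), float("-inf"))

    if attn_mask is not None:
        # Assumed standard handling of boolean vs. additive masks (outside the hunks above).
        if attn_mask.dtype == torch.bool:
            attn_bias = attn_bias.masked_fill(attn_mask.logical_not(), float("-inf"))
        else:
            attn_bias = attn_bias + attn_mask

    attn_weight = torch.matmul(query, key.transpose(-2, -1))

    # Simplified scaling: 1/sqrt(head_dim) when no scale is given, otherwise the provided scale.
    if scale is None:
        attn_weight = attn_weight / (query.size(-1) ** 0.5)
    else:
        attn_weight = attn_weight * scale

    if is_causal or attn_mask is not None:
        attn_weight = attn_weight + attn_bias

    attn_weight = torch.softmax(attn_weight, dim=-1)
    return torch.matmul(attn_weight, value)


# Parity check against the eager reference implementation.
q, k, v = (torch.randn(2, 4, 8, 16) for _ in range(3))
ref = torch.nn.functional.scaled_dot_product_attention(q, k, v, is_causal=True)
out = sdpa_decomposed(q, k, v, is_causal=True)
print((ref - out).abs().max())  # expected to be within float32 rounding error

The printed difference should stay at float32 rounding-error level, illustrating that skipping the all-zero attn_bias in the unmasked case does not change the result.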