@@ -1741,6 +1741,64 @@ def _attention_scale(query: TFloat) -> TFloat:
     return scale
 
 
+def _attention_repeat_kv_for_group_query(
+    query: TFloat, key: TFloat, value: TFloat
+) -> Tuple[TFloat, TFloat]:
+    """Expand key and value for group query attention.
+
+    repeat_interleave is applied to key and value to match the number of heads in query.
+
+    Args:
+        query: Tensor of shape [B, q_num_heads, q_S, E]
+        key: Tensor of shape [B, k_num_heads, kv_S, E]
+        value: Tensor of shape [B, v_num_heads, kv_S, E]
+
+    Returns:
+        Tuple of (expanded_key, expanded_value) where:
+        - expanded_key: Tensor of shape [B, q_num_heads, kv_S, E]
+        - expanded_value: Tensor of shape [B, q_num_heads, kv_S, E]
+    """
+
+    assert (
+        query.shape[1] > key.shape[1] == value.shape[1] and query.shape[1] % key.shape[1] == 0
+    ), (
+        "SDPA (GQA or MQA) requires q_num_heads > kv_num_heads & q_num_heads % kv_num_heads == 0"
+    )
+
+    # NOTE: QKV are expected to be 4D tensors
+
+    batch_size = op.Shape(query, start=0, end=1)  # [B]
+    q_num_heads = op.Shape(query, start=1, end=2)  # [Hq]
+    kv_num_heads = op.Shape(key, start=1, end=2)  # [Hk]
+    qk_head_size = op.Shape(key, start=3, end=4)  # [Dk]
+    v_head_size = op.Shape(value, start=3, end=4)  # [Dv]
+    new_kv_seq_len = op.Shape(key, start=2, end=3)  # [T]
+
+    interleave_dim = op.Div(q_num_heads, kv_num_heads)  # Hq / Hk
+    two = op.Constant(value_int=2)
+    k_unsqueezed = op.Unsqueeze(key, two)  # [B, Hk, 1, T, Dk]
+    v_unsqueezed = op.Unsqueeze(value, two)  # [B, Hv, 1, T, Dv]
+
+    k_expand_shape = op.Concat(
+        batch_size, kv_num_heads, interleave_dim, new_kv_seq_len, qk_head_size, axis=0
+    )
+    k_expand = op.Expand(k_unsqueezed, k_expand_shape)
+    v_expand_shape = op.Concat(
+        batch_size, kv_num_heads, interleave_dim, new_kv_seq_len, v_head_size, axis=0
+    )
+    v_expand = op.Expand(v_unsqueezed, v_expand_shape)
+
+    k_attention_shape = op.Concat(
+        batch_size, q_num_heads, new_kv_seq_len, qk_head_size, axis=0
+    )
+    v_attention_shape = op.Concat(batch_size, q_num_heads, new_kv_seq_len, v_head_size, axis=0)
+
+    expanded_key = op.Reshape(k_expand, k_attention_shape)
+    expanded_value = op.Reshape(v_expand, v_attention_shape)
+
+    return expanded_key, expanded_value
+
+
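For intuition, the Unsqueeze/Expand/Reshape sequence built above is the graph-level counterpart of an eager `repeat_interleave` on the head axis. A minimal sketch (not part of this diff; assumes PyTorch is available and uses made-up shapes) checking that the two layouts agree:

```python
# Hypothetical sanity check, not part of the change.
import torch

B, Hq, Hkv, T, D = 2, 8, 2, 5, 16
key = torch.randn(B, Hkv, T, D)
n_rep = Hq // Hkv  # plays the role of interleave_dim above

# Graph construction: [B, Hkv, T, D] -> [B, Hkv, 1, T, D]
#                  -> [B, Hkv, n_rep, T, D] -> [B, Hq, T, D]
expanded = key.unsqueeze(2).expand(B, Hkv, n_rep, T, D).reshape(B, Hq, T, D)

# Eager reference: repeat each KV head n_rep times along the head dimension.
reference = key.repeat_interleave(n_rep, dim=1)
assert torch.equal(expanded, reference)
```

Both produce each KV head repeated contiguously, which is the layout the `q_num_heads % kv_num_heads == 0` assertion relies on.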
 @torch_op("aten::scaled_dot_product_attention", trace_only=True)
 def aten_scaled_dot_product_attention(
     query: TFloat,
@@ -1772,8 +1830,8 @@ def aten_scaled_dot_product_attention(
17721830 "is_causal and attn_mask cannot be set at the same time"
17731831 )
17741832
1775- assert not enable_gqa , (
1776- "conversion of scaled_dot_product_attention not implemented if enable_gqa is True "
1833+ assert len ( query . shape ) == 4 and len ( key . shape ) == 4 and len ( value . shape ) == 4 , (
1834+ "only 4D query, key, and value are supported "
17771835 )
17781836
17791837 # Reference: https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html
@@ -1784,6 +1842,13 @@ def aten_scaled_dot_product_attention(
     if is_causal:
         attn_mask = _causal_attention_mask(query, key)
 
+    if enable_gqa:
+        key, value = _attention_repeat_kv_for_group_query(query, key, value)
+    else:
+        assert query.shape[1] == key.shape[1] == value.shape[1], (
+            "SDPA (MHA) requires q_num_heads = kv_num_heads"
+        )
+
     if attn_mask is None:
         return _aten_scaled_dot_product_attention_no_mask_onnx(
             query, key, value, scale, dropout_p
@@ -1981,9 +2046,8 @@ def aten_scaled_dot_product_attention_bool_mask(
     assert (not is_causal) or (is_causal and attn_mask is None), (
         "is_causal and attn_mask cannot be set at the same time"
     )
-
-    assert not enable_gqa, (
-        "conversion of scaled_dot_product_attention not implemented if enable_gqa is True"
+    assert len(query.shape) == 4 and len(key.shape) == 4 and len(value.shape) == 4, (
+        "only 4D query, key, and value are supported"
     )
 
     if scale is None:
@@ -1997,6 +2061,9 @@ def aten_scaled_dot_product_attention_bool_mask(
             query, key, value, attn_mask, scale, dropout_p
         )
 
+    if enable_gqa:
+        key, value = _attention_repeat_kv_for_group_query(query, key, value)
+
     if attn_mask is None:
         return _aten_scaled_dot_product_attention_no_mask_onnx(
             query, key, value, scale, dropout_p