
Commit 8ed3521

Support enable_gqa and only support 4D Q, K, and V (#2558)

1. Support `enable_gqa`.
2. Align with the PyTorch exporter setting, which rejects Q, K, and V when they are not 4D: https://github.com/pytorch/pytorch/blob/62843c14bbf694f5722fd6e1075da4792507fe42/torch/onnx/_internal/exporter/_torchlib/ops/nn.py#L131-L133

NOTE: `torch.nn.functional.scaled_dot_product_attention` itself does support 3D inputs, and op tests even exercise a 3D Q with 4D K and V.
Parent: 366f7be
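For background, grouped-query attention (GQA) gives the query more heads than key/value, and PyTorch defines `enable_gqa=True` as equivalent to repeating each KV head until the head counts match. A minimal sketch of that contract, assuming a PyTorch build where SDPA supports `enable_gqa` (shapes are arbitrary illustration values):

```python
import torch
import torch.nn.functional as F

q = torch.randn(2, 8, 32, 64)  # [B, Hq,  S, E]
k = torch.randn(2, 2, 32, 64)  # [B, Hkv, S, E]; Hq % Hkv == 0
v = torch.randn(2, 2, 32, 64)

# GQA path: each KV head is shared by a group of query heads
out_gqa = F.scaled_dot_product_attention(q, k, v, enable_gqa=True)

# Equivalent MHA path: repeat each KV head Hq // Hkv times first
g = q.shape[1] // k.shape[1]
out_mha = F.scaled_dot_product_attention(
    q, k.repeat_interleave(g, dim=1), v.repeat_interleave(g, dim=1)
)
torch.testing.assert_close(out_gqa, out_mha)
```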

File tree

3 files changed: +114 −5 lines changed

- onnxscript/function_libs/torch_lib/ops/nn.py
- tests/function_libs/torch_lib/e2e_ops_tests.py
- tests/function_libs/torch_lib/ops_test_data.py

onnxscript/function_libs/torch_lib/ops/nn.py

Lines changed: 72 additions & 5 deletions
```diff
@@ -1741,6 +1741,64 @@ def _attention_scale(query: TFloat) -> TFloat:
     return scale
 
 
+def _attention_repeat_kv_for_group_query(
+    query: TFloat, key: TFloat, value: TFloat
+) -> Tuple[TFloat, TFloat]:
+    """Expand key and value for group query attention.
+
+    repeat_interleave is applied on key and value to match the number of heads in query.
+
+    Args:
+        query: Tensor of shape [B, q_num_heads, q_S, E]
+        key: Tensor of shape [B, k_num_heads, kv_S, E]
+        value: Tensor of shape [B, v_num_heads, kv_S, E]
+
+    Returns:
+        Tuple of (expanded_key, expanded_value) where:
+        - expanded_key: Tensor of shape [B, q_num_heads, kv_S, E]
+        - expanded_value: Tensor of shape [B, q_num_heads, kv_S, E]
+    """
+
+    assert (
+        query.shape[1] > key.shape[1] == value.shape[1] and query.shape[1] % key.shape[1] == 0
+    ), (
+        "SDPA (GQA or MQA) requires q_num_heads > kv_num_heads & q_num_heads % kv_num_heads == 0"
+    )
+
+    # NOTE: QKV are expected to be 4D tensors
+
+    batch_size = op.Shape(query, start=0, end=1)  # [B]
+    q_num_heads = op.Shape(query, start=1, end=2)  # [Hq]
+    kv_num_heads = op.Shape(key, start=1, end=2)  # [Hk]
+    qk_head_size = op.Shape(key, start=3, end=4)  # [Dk]
+    v_head_size = op.Shape(value, start=3, end=4)  # [Dv]
+    new_kv_seq_len = op.Shape(key, start=2, end=3)  # [T]
+
+    interleave_dim = op.Div(q_num_heads, kv_num_heads)  # Hq / Hk
+    two = op.Constant(value_int=2)
+    k_unsqueezed = op.Unsqueeze(key, two)  # [B, Hk, 1, T, Dk]
+    v_unsqueezed = op.Unsqueeze(value, two)  # [B, Hv, 1, T, Dv]
+
+    k_expand_shape = op.Concat(
+        batch_size, kv_num_heads, interleave_dim, new_kv_seq_len, qk_head_size, axis=0
+    )
+    k_expand = op.Expand(k_unsqueezed, k_expand_shape)
+    v_expand_shape = op.Concat(
+        batch_size, kv_num_heads, interleave_dim, new_kv_seq_len, v_head_size, axis=0
+    )
+    v_expand = op.Expand(v_unsqueezed, v_expand_shape)
+
+    k_attention_shape = op.Concat(
+        batch_size, q_num_heads, new_kv_seq_len, qk_head_size, axis=0
+    )
+    v_attention_shape = op.Concat(batch_size, q_num_heads, new_kv_seq_len, v_head_size, axis=0)
+
+    expanded_key = op.Reshape(k_expand, k_attention_shape)
+    expanded_value = op.Reshape(v_expand, v_attention_shape)
+
+    return expanded_key, expanded_value
+
+
 @torch_op("aten::scaled_dot_product_attention", trace_only=True)
 def aten_scaled_dot_product_attention(
     query: TFloat,
```
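The Unsqueeze → Expand → Reshape sequence above is the standard way to express `repeat_interleave` with ONNX shape ops. A small PyTorch sketch of the same trick, with hypothetical sizes, to show the two are equal:

```python
import torch

B, Hk, g, T, Dk = 2, 2, 2, 8, 16  # g = Hq / Hk, query heads per KV head
key = torch.randn(B, Hk, T, Dk)

# [B, Hk, T, Dk] -> [B, Hk, 1, T, Dk] -> [B, Hk, g, T, Dk] -> [B, Hk*g, T, Dk]
expanded = key.unsqueeze(2).expand(B, Hk, g, T, Dk).reshape(B, Hk * g, T, Dk)

# Identical to repeat_interleave along the head dimension
assert torch.equal(expanded, key.repeat_interleave(g, dim=1))
```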
```diff
@@ -1772,8 +1830,8 @@ def aten_scaled_dot_product_attention(
         "is_causal and attn_mask cannot be set at the same time"
     )
 
-    assert not enable_gqa, (
-        "conversion of scaled_dot_product_attention not implemented if enable_gqa is True"
+    assert len(query.shape) == 4 and len(key.shape) == 4 and len(value.shape) == 4, (
+        "only 4D query, key, and value are supported"
     )
 
     # Reference: https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html
```
```diff
@@ -1784,6 +1842,13 @@ def aten_scaled_dot_product_attention(
     if is_causal:
         attn_mask = _causal_attention_mask(query, key)
 
+    if enable_gqa:
+        key, value = _attention_repeat_kv_for_group_query(query, key, value)
+    else:
+        assert query.shape[1] == key.shape[1] == value.shape[1], (
+            "SDPA (MHA) requires q_num_heads = kv_num_heads"
+        )
+
     if attn_mask is None:
         return _aten_scaled_dot_product_attention_no_mask_onnx(
             query, key, value, scale, dropout_p
```
```diff
@@ -1981,9 +2046,8 @@ def aten_scaled_dot_product_attention_bool_mask(
     assert (not is_causal) or (is_causal and attn_mask is None), (
         "is_causal and attn_mask cannot be set at the same time"
     )
-
-    assert not enable_gqa, (
-        "conversion of scaled_dot_product_attention not implemented if enable_gqa is True"
+    assert len(query.shape) == 4 and len(key.shape) == 4 and len(value.shape) == 4, (
+        "only 4D query, key, and value are supported"
     )
 
     if scale is None:
```
```diff
@@ -1997,6 +2061,9 @@ def aten_scaled_dot_product_attention_bool_mask(
             query, key, value, attn_mask, scale, dropout_p
         )
 
+    if enable_gqa:
+        key, value = _attention_repeat_kv_for_group_query(query, key, value)
+
     if attn_mask is None:
         return _aten_scaled_dot_product_attention_no_mask_onnx(
             query, key, value, scale, dropout_p
```

tests/function_libs/torch_lib/e2e_ops_tests.py

Lines changed: 30 additions & 0 deletions
```diff
@@ -195,6 +195,36 @@ def forward(self, x):
         )
         _testing.assert_onnx_program(onnx_program)
 
+    def test_enable_gqa_in_attention(self):
+        class Model(torch.nn.Module):
+            def forward(self, q, k, v):
+                return torch.nn.functional.scaled_dot_product_attention(  # pylint: disable=not-callable
+                    q,
+                    k,
+                    v,
+                    enable_gqa=True,
+                )
+
+        model = Model()
+
+        query = torch.randn(2, 4, 8, 16)
+        key = torch.randn(2, 2, 8, 16)
+        value = torch.randn(2, 2, 8, 16)
+
+        onnx_program = torch.onnx.export(
+            model,
+            (
+                query,
+                key,
+                value,
+            ),
+            input_names=["query", "key", "value"],
+            output_names=["output"],
+            opset_version=18,
+            dynamo=True,
+        )
+        _testing.assert_onnx_program(onnx_program)
+
 
 if __name__ == "__main__":
     unittest.main()
```
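Locally, the new test can be selected by name with pytest's `-k` filter, e.g. `python -m pytest tests/function_libs/torch_lib/e2e_ops_tests.py -k test_enable_gqa_in_attention` (assuming a dev environment with `torch` and `onnxruntime` installed).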

tests/function_libs/torch_lib/ops_test_data.py

Lines changed: 12 additions & 0 deletions
```diff
@@ -1908,6 +1908,12 @@ def _where_input_wrangler(
         dtypes=(torch.float16,),
         reason="fixme: ORT failed. https://github.com/microsoft/onnxruntime/issues/16438",
         test_class_name="TestOutputConsistencyFullGraph",
+    )
+    .xfail(
+        matcher=lambda sample: len(sample.input.shape) != 4
+        or len(sample.args[0].shape) != 4
+        or len(sample.args[1].shape) != 4,
+        reason="torch sdpa is expected to pass in 4d q, k, and v.",
     ),
     TorchLibOpInfo(
         "ops.aten._scaled_dot_product_flash_attention",
```
```diff
@@ -1959,6 +1965,12 @@ def _where_input_wrangler(
         dtypes=(torch.float16,),
         reason="fixme: ORT failed. https://github.com/microsoft/onnxruntime/issues/16438",
         test_class_name="TestOutputConsistencyFullGraph",
+    )
+    .xfail(
+        matcher=lambda sample: len(sample.input.shape) != 4
+        or len(sample.args[0].shape) != 4
+        or len(sample.args[1].shape) != 4,
+        reason="torch sdpa is expected to pass in 4d q, k, and v.",
     ),
     TorchLibOpInfo(
         "ops.aten.upsample_bilinear2d.default",
```
