Skip to content

Commit 70e751a

Browse files
authored
Implement SDPA via MHA (#2683)
Implement SDPA via MHA. This handles the case when earlier fusion rules do not map larger patterns containing SDPA into MHA or GQA or Attention (from ORT contrib ops). It implements SDPA via MHA. --------- Signed-off-by: Ganesan Ramalingam <grama@microsoft.com>
1 parent ea8cb3e commit 70e751a

File tree

4 files changed

+66
-12
lines changed

4 files changed

+66
-12
lines changed

onnxscript/rewriter/ort_fusions/_core.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
fuse_rotary_embedding,
3030
)
3131
from onnxscript.rewriter.ort_fusions.sdpa import fuse_sdpa
32+
from onnxscript.rewriter.ort_fusions.sdpa_via_mha import replace_sdpa_by_mha
3233
from onnxscript.rewriter.ort_fusions.skip_normalization import (
3334
fuse_skip_layer_normalization,
3435
fuse_skip_rms_normalization,
@@ -104,6 +105,7 @@ def fuse(func, **kwargs):
104105
fusion_count["attention"] = fuse(fuse_attention)
105106
fusion_count["gelu"] = fuse(fuse_gelu)
106107
fusion_count["bias_gelu"] = fuse(fuse_bias_gelu)
108+
fusion_count["sdpa_via_mha"] = fuse(replace_sdpa_by_mha)
107109
# Finally: inline any intermediate fusion functions introduced that were not
108110
# consumed by other fusions, and eliminate any remaining unused nodes.
109111
optimize(model)

onnxscript/rewriter/ort_fusions/sdpa.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,18 @@
1212

1313
Dim = Union[int, ir.SymbolicDim]
1414

15+
# This file contains a fusion rule that recognizes various patterns of scaled dot-product attention
16+
# (SDPA) implementations and replaces them with a single SDPA op. The SDPA op is a temporary fusion
17+
# op defined in the ai.onnxruntime._fusion domain. Subsequent fusion rules will map it into one
18+
# of the various ops defined in ORT: MHA, GQA, or Attention depending on the input patterns.
19+
# The SDPA op is a standard scaled dot-product attention with an optional mask input and scaling factor.
20+
# Currently, it is restricted to query, key, and values of rank 4 with shapes:
21+
# Query: [batch_size, num_heads, seq_len, head_size_qk]
22+
# Key: [batch_size, num_heads, seq_len_kv, head_size_qk]
23+
# or [batch_size, seq_len_kv, num_heads, head_size_qk]
24+
# Value: [batch_size, num_heads, seq_len_kv, head_size_v]
25+
# The key_format attribute indicates which of the two formats the key uses and can be either "BHSd" or "BSHd".
26+
1527

1628
class SDPA(pattern.RewriteRuleClassBase):
1729
_scale: float | None

onnxscript/rewriter/ort_fusions/sdpa_test.py

Lines changed: 32 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -292,20 +292,41 @@ def _masked_custom_scale_post_mul_sdpa_script(query, key, value, mask):
292292
return attn_output
293293

294294

295+
# This tests a scenario where the key is in BSHd format instead of BHSd, which
296+
# happens due to an optimization that fuses two transposes together, the one
297+
# to convert from BSHd to BHSd and then to BHdS before MatMul. Hence, the first
298+
# transpose down below is different from other test cases.
299+
@script()
300+
def _unmasked_pre_div_sdpa_BSHd_key_script(query, key, value):
301+
key_transposed = op.Transpose(key, perm=[0, 2, 3, 1]) # BSHd to BHdS
302+
divisor = op.Constant(value_float=SQRT_SCALE_FACTOR)
303+
scaled_query = op.Div(query, divisor)
304+
scaled_key = op.Div(key_transposed, divisor)
305+
attn_score = op.MatMul(scaled_query, scaled_key)
306+
attn_weight = op.Softmax(attn_score, axis=-1)
307+
is_nan = op.IsNaN(attn_weight)
308+
zero = op.Constant(value_float=0.0)
309+
adj_attn_weight = op.Where(is_nan, zero, attn_weight)
310+
attn_output = op.MatMul(adj_attn_weight, value)
311+
return attn_output
312+
313+
295314
class SDPATestCase:
296-
def __init__(self, script_func, *, with_mask):
315+
def __init__(self, script_func, *, with_mask, BSHd_key=False):
297316
self.script_func = script_func
298317
self.with_mask = with_mask
318+
self.BSHd_key = BSHd_key
299319

300320
def get_onnx_model(self):
301321
if not hasattr(self, "_onnx_model"):
302-
qkv_type = FLOAT[B, N, S, H]
322+
qv_type = FLOAT[B, N, S, H]
303323
mask_type = FLOAT[B, N, S, S]
304-
input_types = [qkv_type, qkv_type, qkv_type]
324+
k_type = FLOAT[B, S, N, H] if self.BSHd_key else FLOAT[B, N, S, H]
325+
input_types = [qv_type, k_type, qv_type]
305326
if self.with_mask:
306327
input_types.append(mask_type)
307328
model_proto = self.script_func.to_model_proto(
308-
input_types=input_types, output_types=[qkv_type]
329+
input_types=input_types, output_types=[qv_type]
309330
)
310331
self._onnx_model = ir.serde.deserialize_model(model_proto)
311332
return self._onnx_model
@@ -314,7 +335,9 @@ def get_ort_inputs(self):
314335
if not hasattr(self, "_ort_inputs"):
315336
inputs = {
316337
"query": numpy.random.rand(B, N, S, H).astype(numpy.float32),
317-
"key": numpy.random.rand(B, N, S, H).astype(numpy.float32),
338+
"key": numpy.random.rand(B, S, N, H).astype(numpy.float32)
339+
if self.BSHd_key
340+
else numpy.random.rand(B, N, S, H).astype(numpy.float32),
318341
"value": numpy.random.rand(B, N, S, H).astype(numpy.float32),
319342
}
320343
if self.with_mask:
@@ -374,10 +397,13 @@ class TestSDPAFusion(unittest.TestCase):
374397
"_custom_multi_scale_pre_mul_sdpa_script",
375398
_custom_multi_scale_pre_mul_sdpa_script,
376399
),
400+
("pre_div_sdpa_BSHd_key", _unmasked_pre_div_sdpa_BSHd_key_script),
377401
]
378402
)
379403
def test_sdpa_fusion(self, name, script_func):
380-
test_case = SDPATestCase(script_func, with_mask="masked" in name)
404+
test_case = SDPATestCase(
405+
script_func, with_mask="masked" in name, BSHd_key="BSHd_key" in name
406+
)
381407
model = test_case.get_onnx_model()
382408
onnxscript.optimizer.optimize(model)
383409

onnxscript/rewriter/ort_fusions/sdpa_via_mha.py

Lines changed: 20 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -7,43 +7,57 @@
77
import onnx_ir as ir
88

99
from onnxscript.rewriter import _fusion_utils, pattern
10+
from onnxscript.rewriter._basics import MatchFailureError
1011

1112
Dim = Union[int, ir.SymbolicDim]
1213

1314

1415
class SDPAImplementation(pattern.RewriteRuleClassBase):
15-
def pattern(self, op, query, key, value):
16+
def pattern(self, op, query, key, value, key_format):
17+
"""Pattern matches any call to SDPA. See sdpa.py for documentation on the SDPA op."""
1618
return op.SDPA(
1719
query,
1820
key,
1921
value,
20-
key_format="BHSd",
22+
key_format=key_format,
2123
_allow_other_inputs=True, # Mask is optional
2224
_outputs=["sdpa_output"],
2325
_domain="ai.onnxruntime._fusion",
2426
)
2527

26-
def check(self, context, query, key, value, sdpa_output):
28+
def check(self, context, query, key, value, key_format, sdpa_output):
2729
bindings: dict[str, Dim] = {}
2830
_fusion_utils.check_shape(bindings, query, ["B", "H", "S", "Dh"])
29-
_fusion_utils.check_shape(bindings, key, ["B", "H", "Skv", "Dh"])
3031
_fusion_utils.check_shape(bindings, value, ["B", "H", "Skv", "Dv"])
3132

33+
if key_format.value == "BHSd":
34+
_fusion_utils.check_shape(bindings, key, ["B", "H", "Skv", "Dh"])
35+
elif key_format.value == "BSHd":
36+
_fusion_utils.check_shape(bindings, key, ["B", "Skv", "H", "Dh"])
37+
else:
38+
raise MatchFailureError(
39+
f"Unexpected key_format value: {key_format.value}", key_format
40+
)
41+
3242
self._num_heads = bindings["H"]
3343
if not isinstance(self._num_heads, int):
3444
return False
3545
self._use_mask_broadcast = True # TODO: optimize to avoid broadcast if not needed
3646
return isinstance(self._num_heads, int)
3747

38-
def rewrite(self, op, query, key, value, sdpa_output):
48+
def rewrite(self, op, query, key, value, key_format, sdpa_output):
3949
sdpa_node = sdpa_output.producer()
4050
scale = sdpa_node.attributes.get("scale", None)
4151
to_3d_shape = op.Constant(value_ints=[0, 0, -1])
4252
to_4d_shape = op.Constant(value_ints=[0, 0, self._num_heads, -1])
4353
query_3d = op.Reshape(op.Transpose(query, perm=[0, 2, 1, 3]), to_3d_shape)
44-
key_3d = op.Reshape(op.Transpose(key, perm=[0, 2, 1, 3]), to_3d_shape)
4554
value_3d = op.Reshape(op.Transpose(value, perm=[0, 2, 1, 3]), to_3d_shape)
4655

56+
if key_format.value == "BHSd":
57+
key_3d = op.Reshape(op.Transpose(key, perm=[0, 2, 1, 3]), to_3d_shape)
58+
else: # BSHd
59+
key_3d = op.Reshape(key, to_3d_shape)
60+
4761
inputs = [query_3d, key_3d, value_3d]
4862
if len(sdpa_node.inputs) > 3:
4963
mask = sdpa_node.inputs[3]

0 commit comments

Comments
 (0)