From 86b1687113fe7fa69e427340bdabdb0839d091d6 Mon Sep 17 00:00:00 2001
From: Justin Chu
Date: Tue, 14 Oct 2025 13:42:40 -0700
Subject: [PATCH 1/4] WIP

Signed-off-by: Justin Chu
---
 .../rules/fusion/_attention_present_kv.py     | 111 ++++++++++++++++++
 1 file changed, 111 insertions(+)
 create mode 100644 onnxscript/rewriter/rules/fusion/_attention_present_kv.py

diff --git a/onnxscript/rewriter/rules/fusion/_attention_present_kv.py b/onnxscript/rewriter/rules/fusion/_attention_present_kv.py
new file mode 100644
index 0000000000..7bd638b212
--- /dev/null
+++ b/onnxscript/rewriter/rules/fusion/_attention_present_kv.py
@@ -0,0 +1,111 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
+from __future__ import annotations
+
+from typing import Union
+
+import onnx_ir as ir
+
+import onnxscript.rewriter._fusion_utils as _fusion_utils
+from onnxscript.rewriter import _basics, pattern
+
+Dim = Union[int, ir.SymbolicDim]
+
+
+class AttentionPresentKeyValue(pattern.RewriteRuleClassBase):
+    """Move present_key and present_value to be generated by Attention."""
+    def pattern(
+        self,
+        op,
+        query_BHSD,
+        key_BHkvSD,
+        value_BHkvSD,
+        past_key_BHkvSpD,
+        past_value_BHkvSpD,
+    ):
+        # Concatenate past_key cache and current key, expand across heads
+        # that share key/value.
+
+        present_key_BHkvStD = op.Concat(past_key_BHkvSpD, key_BHkvSD, axis=-2)
+        present_key_BHkv1StD = op.Unsqueeze(present_key_BHkvStD, 2)
+        present_key_BHkvGStD = op.Expand(present_key_BHkv1StD, pattern.ANY_VALUE)
+        present_key_BHStD = op.Reshape(
+            present_key_BHkvGStD, pattern.ANY_VALUE, _outputs=["present_key_BHStD"]
+        )
+
+        # Concatenate past_value cache and current value, expand across heads
+        # that share key/value.
+        present_value_BHkvStD = op.Concat(past_value_BHkvSpD, value_BHkvSD, axis=-2)
+        present_value_BHkv1StD = op.Unsqueeze(present_value_BHkvStD, 2)
+        present_value_BHkvGStD = op.Expand(present_value_BHkv1StD, pattern.ANY_VALUE)
+        present_value_BHStD = op.Reshape(
+            present_value_BHkvGStD, pattern.ANY_VALUE, _outputs=["present_value_BHStD"]
+        )
+
+        attention_BHSDh = op.Attention(
+            query_BHSD,
+            present_key_BHStD,
+            present_value_BHStD,
+            pattern.Var("mask", can_match_none=True),
+            _outputs=["attention_BHSDh"],
+        )
+
+        return attention_BHSDh
+
+    def check(
+        self,
+        context: _basics.MatchContext,
+        query_BHSD,
+        key_BHkvSD,
+        value_BHkvSD,
+        past_key_BHkvSpD,
+        past_value_BHkvSpD,
+        present_key_BHStD,
+        present_value_BHStD,
+        **_,
+    ):
+        bindings: dict[str, Dim] = {}
+        # Check that inputs to new Attention node have expected shapes
+        _fusion_utils.check_shape(bindings, query_BHSD, ["B", "H", "S", "D"])
+        _fusion_utils.check_shape(bindings, key_BHkvSD, ["B", "Hkv", "S", "D"])
+        _fusion_utils.check_shape(bindings, value_BHkvSD, ["B", "Hkv", "S", "D"])
+        _fusion_utils.check_shape(bindings, past_key_BHkvSpD, ["B", "Hkv", "P", "D"])
+        _fusion_utils.check_shape(bindings, past_value_BHkvSpD, ["B", "Hkv", "P", "D"])
+        # We need to check that the Expand/Reshape arguments are as expected.
+        # As a substitute, we check that the outputs of Expand=>Reshape have expected shapes.
+        # TODO (rama): May be better to check the actual Expand/Reshape arguments.
+        _fusion_utils.check_shape(bindings, present_key_BHStD, ["B", "H", "S+P", "D"])
+        _fusion_utils.check_shape(bindings, present_value_BHStD, ["B", "H", "S+P", "D"])
+
+        return True
+
+    def rewrite(
+        self,
+        op,
+        query_BHSD,
+        key_BHkvSD,
+        value_BHkvSD,
+        past_key_BHkvSpD,
+        past_value_BHkvSpD,
+        mask,
+        attention_BHSDh,
+        **_,
+    ):
+        original_attention_node = attention_BHSDh.producer()
+        original_attrs = original_attention_node.attributes
+        return op.Attention(
+            query_BHSD,
+            key_BHkvSD,
+            value_BHkvSD,
+            mask,
+            past_key_BHkvSpD,
+            past_value_BHkvSpD,
+            **original_attrs,
+        )
+
+
+_basic_gqa_rule =
+
+gqa_rules = pattern.RewriteRuleSet([AttentionPresentKeyValue.rule()])
+
+fuse_gqa = _fusion_utils.apply_fusion_rules(gqa_rules)

From 96ea92b6d06c39cf95da86e95b03140fea633d3d Mon Sep 17 00:00:00 2001
From: Justin Chu
Date: Tue, 14 Oct 2025 15:52:48 -0700
Subject: [PATCH 2/4] Create fusion rule for attention kv

Signed-off-by: Justin Chu
---
 .../rules/fusion/_attention_present_kv.py     | 108 +++++-------------
 1 file changed, 27 insertions(+), 81 deletions(-)

diff --git a/onnxscript/rewriter/rules/fusion/_attention_present_kv.py b/onnxscript/rewriter/rules/fusion/_attention_present_kv.py
index 7bd638b212..dbcf438e42 100644
--- a/onnxscript/rewriter/rules/fusion/_attention_present_kv.py
+++ b/onnxscript/rewriter/rules/fusion/_attention_present_kv.py
@@ -2,110 +2,56 @@ # Copyright (c) Microsoft Corporation.
 # Licensed under the MIT License.
 from __future__ import annotations
 
-from typing import Union
-
 import onnx_ir as ir
 
 import onnxscript.rewriter._fusion_utils as _fusion_utils
-from onnxscript.rewriter import _basics, pattern
-
-Dim = Union[int, ir.SymbolicDim]
+from onnxscript.rewriter import pattern
 
 
 class AttentionPresentKeyValue(pattern.RewriteRuleClassBase):
     """Move present_key and present_value to be generated by Attention."""
+
     def pattern(
         self,
         op,
-        query_BHSD,
-        key_BHkvSD,
-        value_BHkvSD,
-        past_key_BHkvSpD,
-        past_value_BHkvSpD,
+        query,
+        key,
+        value,
+        mask,
+        past_key,
+        past_value,
     ):
-        # Concatenate past_key cache and current key, expand across heads
-        # that share key/value.
-
-        present_key_BHkvStD = op.Concat(past_key_BHkvSpD, key_BHkvSD, axis=-2)
-        present_key_BHkv1StD = op.Unsqueeze(present_key_BHkvStD, 2)
-        present_key_BHkvGStD = op.Expand(present_key_BHkv1StD, pattern.ANY_VALUE)
-        present_key_BHStD = op.Reshape(
-            present_key_BHkvGStD, pattern.ANY_VALUE, _outputs=["present_key_BHStD"]
-        )
-
-        # Concatenate past_value cache and current value, expand across heads
-        # that share key/value.
-        present_value_BHkvStD = op.Concat(past_value_BHkvSpD, value_BHkvSD, axis=-2)
-        present_value_BHkv1StD = op.Unsqueeze(present_value_BHkvStD, 2)
-        present_value_BHkvGStD = op.Expand(present_value_BHkv1StD, pattern.ANY_VALUE)
-        present_value_BHStD = op.Reshape(
-            present_value_BHkvGStD, pattern.ANY_VALUE, _outputs=["present_value_BHStD"]
-        )
+        present_key = op.Concat(past_key, key, axis=-2)
+        present_value = op.Concat(past_value, value, axis=-2)
 
-        attention_BHSDh = op.Attention(
-            query_BHSD,
-            present_key_BHStD,
-            present_value_BHStD,
-            pattern.Var("mask", can_match_none=True),
-            _outputs=["attention_BHSDh"],
+        attention_out = op.Attention(
+            query, key, value, mask, past_key, past_value, _outputs=["attention_out"]
         )
 
-        return attention_BHSDh
-
-    def check(
-        self,
-        context: _basics.MatchContext,
-        query_BHSD,
-        key_BHkvSD,
-        value_BHkvSD,
-        past_key_BHkvSpD,
-        past_value_BHkvSpD,
-        present_key_BHStD,
-        present_value_BHStD,
-        **_,
-    ):
-        bindings: dict[str, Dim] = {}
-        # Check that inputs to new Attention node have expected shapes
-        _fusion_utils.check_shape(bindings, query_BHSD, ["B", "H", "S", "D"])
-        _fusion_utils.check_shape(bindings, key_BHkvSD, ["B", "Hkv", "S", "D"])
-        _fusion_utils.check_shape(bindings, value_BHkvSD, ["B", "Hkv", "S", "D"])
-        _fusion_utils.check_shape(bindings, past_key_BHkvSpD, ["B", "Hkv", "P", "D"])
-        _fusion_utils.check_shape(bindings, past_value_BHkvSpD, ["B", "Hkv", "P", "D"])
-        # We need to check that the Expand/Reshape arguments are as expected.
-        # As a substitute, we check that the outputs of Expand=>Reshape have expected shapes.
-        # TODO (rama): May be better to check the actual Expand/Reshape arguments.
-        _fusion_utils.check_shape(bindings, present_key_BHStD, ["B", "H", "S+P", "D"])
-        _fusion_utils.check_shape(bindings, present_value_BHStD, ["B", "H", "S+P", "D"])
-
-        return True
+        return attention_out, present_key, present_value
 
     def rewrite(
         self,
         op,
-        query_BHSD,
-        key_BHkvSD,
-        value_BHkvSD,
-        past_key_BHkvSpD,
-        past_value_BHkvSpD,
-        mask,
-        attention_BHSDh,
+        query: ir.Value,
+        key: ir.Value,
+        value: ir.Value,
+        mask: ir.Value,
+        past_key: ir.Value,
+        past_value: ir.Value,
+        attention_out: ir.Value,
        **_,
     ):
-        original_attention_node = attention_BHSDh.producer()
+        original_attention_node = attention_out.producer()
+        assert original_attention_node is not None
         original_attrs = original_attention_node.attributes
         return op.Attention(
-            query_BHSD,
-            key_BHkvSD,
-            value_BHkvSD,
-            mask,
-            past_key_BHkvSpD,
-            past_value_BHkvSpD,
-            **original_attrs,
+            query, key, value, mask, past_key, past_value, **original_attrs, _outputs=3
         )
 
 
-_basic_gqa_rule =
-
-gqa_rules = pattern.RewriteRuleSet([AttentionPresentKeyValue.rule()])
+attention_present_key_value_rule = AttentionPresentKeyValue.rule()
 
-fuse_gqa = _fusion_utils.apply_fusion_rules(gqa_rules)
+fuse_attention_present_key_value = _fusion_utils.apply_fusion_rules(
+    attention_present_key_value_rule
+)

From d7d2f5da6ffdabde35e7242bbba544c2b95d0df3 Mon Sep 17 00:00:00 2001
From: Justin Chu
Date: Tue, 14 Oct 2025 15:54:50 -0700
Subject: [PATCH 3/4] docs

Signed-off-by: Justin Chu
---
 .../rewriter/rules/fusion/_attention_present_kv.py | 9 ++++++++-
 1 file changed, 8 insertions(+), 1 deletion(-)

diff --git a/onnxscript/rewriter/rules/fusion/_attention_present_kv.py b/onnxscript/rewriter/rules/fusion/_attention_present_kv.py
index dbcf438e42..6c686034f2 100644
--- a/onnxscript/rewriter/rules/fusion/_attention_present_kv.py
+++ b/onnxscript/rewriter/rules/fusion/_attention_present_kv.py
@@ -9,7 +9,14 @@
 
 
 class AttentionPresentKeyValue(pattern.RewriteRuleClassBase):
-    """Move present_key and present_value to be generated by Attention."""
+    """Move present_key and present_value to be generated by Attention.
+
+    When torch.onnx exports a model from transformers with SDPA, it generates a Concat
+    node that concatenates past_key/value with the new key/value to produce the graph
+    outputs for the KV cache. This pattern can be fused into the Attention node, which has
+    present_key and present_value outputs. The fusion is needed for ONNX Runtime, which
+    requires these outputs to be produced by Attention when past_key and past_value are provided.
+    """
 
     def pattern(
         self,

From c408516613088b83d742d42e6b1fd2ed507f23e3 Mon Sep 17 00:00:00 2001
From: Justin Chu
Date: Tue, 14 Oct 2025 15:55:48 -0700
Subject: [PATCH 4/4] Add to default fusion

Signed-off-by: Justin Chu
---
 onnxscript/rewriter/onnx_fusions/_onnx_fusions.py | 10 +++++++++-
 1 file changed, 9 insertions(+), 1 deletion(-)

diff --git a/onnxscript/rewriter/onnx_fusions/_onnx_fusions.py b/onnxscript/rewriter/onnx_fusions/_onnx_fusions.py
index 008a995764..c5e0653ac9 100644
--- a/onnxscript/rewriter/onnx_fusions/_onnx_fusions.py
+++ b/onnxscript/rewriter/onnx_fusions/_onnx_fusions.py
@@ -4,7 +4,12 @@
 
 import onnx_ir as ir
 
-from onnxscript.rewriter.rules.fusion import _gqa, _rms_normalization, _rotary_embedding
+from onnxscript.rewriter.rules.fusion import (
+    _attention_present_kv,
+    _gqa,
+    _rms_normalization,
+    _rotary_embedding,
+)
 
 
 def _get_onnx_opset_version(model: ir.Model) -> int | None:
@@ -25,6 +30,9 @@ def _opset_23_fuse(model: ir.Model, *, debug: bool = False) -> dict[str, int]:
     counts["RMSNormalization"] = _rms_normalization.fuse_rms_normalization(model, debug=debug)
     counts["RotaryEmbedding"] = _rotary_embedding.fuse_rotary_embedding(model, debug=debug)
     counts["GQA"] = _gqa.fuse_gqa(model, debug=debug)
+    counts["AttentionPresentKeyValue"] = (
+        _attention_present_kv.fuse_attention_present_key_value(model, debug=debug)
+    )
     return counts
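
A minimal usage sketch of running the new rule on its own, outside the default opset-23 fusion pass wired up in patch 4. The model path, output path, and the use of onnx_ir's ir.load/ir.save helpers are illustrative assumptions rather than part of the patches; the fuse_attention_present_key_value entry point and its (model, debug=...) signature come from the code above.

    # Sketch: apply only the AttentionPresentKeyValue fusion to an exported model.
    # Assumes "model.onnx" is an opset-23 decoder-style export that still produces its
    # KV-cache outputs through Concat nodes, and that ir.load/ir.save are available.
    import onnx_ir as ir

    from onnxscript.rewriter.rules.fusion import _attention_present_kv

    model = ir.load("model.onnx")  # hypothetical input path
    count = _attention_present_kv.fuse_attention_present_key_value(model, debug=False)
    print(f"Rewrote {count} Concat-based present_key/value pattern(s)")
    ir.save(model, "model_fused.onnx")  # hypothetical output path

When the rule fires, the Concat nodes feeding the graph's present_key/present_value outputs are replaced and those outputs are taken from the Attention node's second and third outputs instead, which, per the docstring above, is the layout ONNX Runtime expects when past_key and past_value inputs are supplied.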