diff --git a/onnxscript/rewriter/onnx_fusions/_onnx_fusions.py b/onnxscript/rewriter/onnx_fusions/_onnx_fusions.py
index 008a995764..c5e0653ac9 100644
--- a/onnxscript/rewriter/onnx_fusions/_onnx_fusions.py
+++ b/onnxscript/rewriter/onnx_fusions/_onnx_fusions.py
@@ -4,7 +4,12 @@
 
 import onnx_ir as ir
 
-from onnxscript.rewriter.rules.fusion import _gqa, _rms_normalization, _rotary_embedding
+from onnxscript.rewriter.rules.fusion import (
+    _attention_present_kv,
+    _gqa,
+    _rms_normalization,
+    _rotary_embedding,
+)
 
 
 def _get_onnx_opset_version(model: ir.Model) -> int | None:
@@ -25,6 +30,9 @@ def _opset_23_fuse(model: ir.Model, *, debug: bool = False) -> dict[str, int]:
     counts["RMSNormalization"] = _rms_normalization.fuse_rms_normalization(model, debug=debug)
     counts["RotaryEmbedding"] = _rotary_embedding.fuse_rotary_embedding(model, debug=debug)
     counts["GQA"] = _gqa.fuse_gqa(model, debug=debug)
+    counts["AttentionPresentKeyValue"] = (
+        _attention_present_kv.fuse_attention_present_key_value(model, debug=debug)
+    )
 
     return counts
 
diff --git a/onnxscript/rewriter/rules/fusion/_attention_present_kv.py b/onnxscript/rewriter/rules/fusion/_attention_present_kv.py
new file mode 100644
index 0000000000..6c686034f2
--- /dev/null
+++ b/onnxscript/rewriter/rules/fusion/_attention_present_kv.py
@@ -0,0 +1,66 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
+from __future__ import annotations
+
+import onnx_ir as ir
+
+import onnxscript.rewriter._fusion_utils as _fusion_utils
+from onnxscript.rewriter import pattern
+
+
+class AttentionPresentKeyValue(pattern.RewriteRuleClassBase):
+    """Move present_key and present_value so that they are produced by the Attention node.
+
+    When torch.onnx exports a transformers model that uses SDPA, it generates Concat
+    nodes that append the new key/value to past_key/past_value to produce the graph
+    outputs for the KV cache. These Concat nodes can be fused into the Attention node,
+    which has present_key and present_value outputs. ONNX Runtime requires these outputs
+    to be produced by the Attention node when the past_key and past_value inputs are provided.
+    """
+
+    def pattern(
+        self,
+        op,
+        query,
+        key,
+        value,
+        mask,
+        past_key,
+        past_value,
+    ):
+        present_key = op.Concat(past_key, key, axis=-2)
+        present_value = op.Concat(past_value, value, axis=-2)
+
+        attention_out = op.Attention(
+            query, key, value, mask, past_key, past_value, _outputs=["attention_out"]
+        )
+
+        return attention_out, present_key, present_value
+
+    def rewrite(
+        self,
+        op,
+        query: ir.Value,
+        key: ir.Value,
+        value: ir.Value,
+        mask: ir.Value,
+        past_key: ir.Value,
+        past_value: ir.Value,
+        attention_out: ir.Value,
+        **_,
+    ):
+        # Re-emit Attention with the original node's attributes; _outputs=3 makes it
+        # also produce present_key and present_value, replacing the Concat nodes.
+        original_attention_node = attention_out.producer()
+        assert original_attention_node is not None
+        original_attrs = original_attention_node.attributes
+        return op.Attention(
+            query, key, value, mask, past_key, past_value, **original_attrs, _outputs=3
+        )
+
+
+attention_present_key_value_rule = AttentionPresentKeyValue.rule()
+
+fuse_attention_present_key_value = _fusion_utils.apply_fusion_rules(
+    attention_present_key_value_rule
+)
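
For reviewers, here is a minimal sketch of how the new fusion could be exercised on its own, outside of `_opset_23_fuse`. The `model.onnx` path and the opset-23 input model are assumptions for illustration; `fuse_attention_present_key_value` takes the same `(model, debug=...)` arguments used in `_onnx_fusions.py` above and returns the number of patterns it fused, matching how the other fusion counts are collected there.

```python
# Hypothetical usage sketch (not part of this PR): apply only the new rule.
import onnx_ir as ir

from onnxscript.rewriter.rules.fusion import _attention_present_kv

model = ir.load("model.onnx")  # assumed: an opset-23 model exported by torch.onnx
count = _attention_present_kv.fuse_attention_present_key_value(model, debug=True)
print(f"AttentionPresentKeyValue fusions applied: {count}")
ir.save(model, "model.fused.onnx")
```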