Commit dd8cb69
[DRAFT] Extend GQA fusion for Gemma3 (#2639)

Gemma3 applies an extra (simplified) normalization to query and key before the rotary embedding. Extend GQA fusion to handle this. TODO: a test case will be added separately.

Signed-off-by: Ganesan Ramalingam <grama@microsoft.com>
Co-authored-by: Justin Chu <justinchuby@users.noreply.github.com>
Parent commit: 8089bc7
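For context on what the commit matches: SimplifiedLayerNormalization (the contrib op used in the pattern below) is an RMS-style normalization, scaling each last-axis slice by the reciprocal of its root mean square, with no mean subtraction and no bias. The sketch below is a plain-numpy reference for that computation and for where Gemma3 places it in the query path (after the per-head transpose, before the rotary embedding); the shapes, epsilon, and helper names are illustrative assumptions, not values taken from the fused model. The key path gets the same treatment on its (B, Hkv, S, D/H) tensor.

```python
# Plain-numpy sketch (illustrative shapes/epsilon, not the fused model's values).
import numpy as np

def simplified_layer_norm(x: np.ndarray, scale: np.ndarray, eps: float = 1e-6) -> np.ndarray:
    """RMS-style normalization over the last axis: no mean subtraction, no bias."""
    rms = np.sqrt(np.mean(np.square(x), axis=-1, keepdims=True) + eps)
    return (x / rms) * scale

B, S, H, Dh = 1, 4, 8, 16                                  # hypothetical dims
query_BSHDh = np.random.randn(B, S, H, Dh).astype(np.float32)
query_BHSDh = query_BSHDh.transpose(0, 2, 1, 3)            # (B, H, S, D/H)
q_norm_scale = np.ones(Dh, dtype=np.float32)               # per-head-dim scale
query_BHSDh = simplified_layer_norm(query_BHSDh, q_norm_scale)  # Gemma3's extra step
# ...rotary embedding and attention follow in the actual model.
```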


onnxscript/rewriter/ort_fusions/gqa.py

Lines changed: 39 additions & 0 deletions
@@ -166,11 +166,23 @@ def pattern(
         # Transpose from (B, S, H, D/H) to (B, H, S, D/H)
         query_BHSDh = op.Transpose(query_BSHDh, perm=[0, 2, 1, 3])
 
+        # Gemma variant uses normalization of query/key before rotary embedding:
+        query_BHSDh_normalized = op.SimplifiedLayerNormalization(
+            query_BHSDh, pattern.ANY_VALUE, axis=-1, _outputs=["query_BHSDh_normalized"]
+        )
+        query_BHSDh = pattern.OrValue([query_BHSDh, query_BHSDh_normalized])
+
         # Reshape key from (B, S, Dkv) to (B, S, Hkv, D/H)
         key_BSHkvDh = op.Reshape(key_BSDkv, pattern.ANY_VALUE, _outputs=["key_BSHkvDh"])
         # Transpose from (B, S, Hkv, D/H) to (B, Hkv, S, D/H)
         key_BHkvSDh = op.Transpose(key_BSHkvDh, perm=[0, 2, 1, 3])
 
+        # Gemma variant uses normalization of query/key before rotary embedding:
+        key_BHkvSDh_normalized = op.SimplifiedLayerNormalization(
+            key_BHkvSDh, pattern.ANY_VALUE, axis=-1, _outputs=["key_BHkvSDh_normalized"]
+        )
+        key_BHkvSDh = pattern.OrValue([key_BHkvSDh, key_BHkvSDh_normalized])
+
         # Reshape value from (B, S, Dkv) to (B, S, Hkv, D/H)
         value_BSHkvDh = op.Reshape(value_BSDkv, pattern.ANY_VALUE, _outputs=["value_BSHkvDh"])
         # Transpose from (B, S, Hkv, D/H) to (B, Hkv, S, D/H)
@@ -316,6 +328,10 @@ def rewrite(
         cos,
         sin,
         mask,
+        query_BSHDh,
+        key_BSHkvDh,
+        query_BHSDh_normalized=None,
+        key_BHkvSDh_normalized=None,
         **_,
     ):
         # Note that the following optimization is specific to current ORT GenAI attention-mask
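The four parameters added to rewrite's signature are the extra values the extended pattern can bind: query_BSHDh and key_BSHkvDh are the reshaped, pre-transpose tensors the replacement re-normalizes, while the *_normalized outputs default to None so one rewrite body serves both the plain match and the Gemma3 match. Below is a hypothetical plain-Python analogue of that dispatch (names mirror the diff; this is not the onnxscript rewriter API):

```python
# Hypothetical analogue of the branch added in rewrite(); NOT the onnxscript
# rewriter API, only the control flow that the None defaults make possible.
from typing import Any, Optional

def rewrite_dispatch(
    query_BSD: Any,
    key_BSDkv: Any,
    query_BHSDh_normalized: Optional[Any] = None,
    key_BHkvSDh_normalized: Optional[Any] = None,
) -> str:
    if query_BHSDh_normalized is None and key_BHkvSDh_normalized is None:
        # Original behavior: packed query/key go straight into GroupQueryAttention.
        return "plain variant"
    # Gemma3 behavior: re-emit the q/k normalization (pre-transpose), then fuse.
    return "Gemma3 variant"

assert rewrite_dispatch("q_BSD", "k_BSDkv") == "plain variant"
assert rewrite_dispatch("q_BSD", "k_BSDkv", "q_norm", "k_norm") == "Gemma3 variant"
```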
@@ -335,6 +351,29 @@
         seqlens_k = op.Cast(seqlens_k_int64, to=ir.DataType.INT32)
         max_seq_length = op.ReduceMax(seqlens_k, zero_int64_1d, keepdims=0)
         total_seq_length_int32 = op.Add(max_seq_length, one_int32_0d)
+
+        if query_BHSDh_normalized is not None:
+            # We apply normalization without the transpose, which is fused into GQA
+            norm_node = query_BHSDh_normalized.producer()
+            norm_attrs = norm_node.attributes
+            norm_scale = norm_node.inputs[1]
+            query_BSHDh_normalized = op.SimplifiedLayerNormalization(
+                query_BSHDh, norm_scale, **norm_attrs
+            )
+            reshape_BSHDh_to_BSD = op.Constant(value_ints=[0, 0, -1])
+            query_BSD = op.Reshape(query_BSHDh_normalized, reshape_BSHDh_to_BSD)
+
+        if key_BHkvSDh_normalized is not None:
+            # We apply normalization without the transpose, which is fused into GQA
+            norm_node = key_BHkvSDh_normalized.producer()
+            norm_attrs = norm_node.attributes
+            norm_scale = norm_node.inputs[1]
+            key_BSHkvDh_normalized = op.SimplifiedLayerNormalization(
+                key_BSHkvDh, norm_scale, **norm_attrs
+            )
+            reshape_BSHkvDh_to_BSDkv = op.Constant(value_ints=[0, 0, -1])
+            key_BSDkv = op.Reshape(key_BSHkvDh_normalized, reshape_BSHkvDh_to_BSDkv)
+
         return op.GroupQueryAttention(
             query_BSD,
             key_BSDkv,
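The comment "We apply normalization without the transpose" rests on the fact that a normalization over the last axis commutes with a transpose of the leading axes, so the rewrite may normalize the (B, S, H, D/H) tensor directly and then flatten it to the packed (B, S, D) layout that GroupQueryAttention consumes; the Reshape shape [0, 0, -1] keeps the first two dimensions and infers the last. Below is a small numpy check of both facts, with assumed shapes, epsilon, and helper names (all illustrative):

```python
# Numpy check (illustrative shapes/epsilon) that the rewrite's reordering is sound:
#   normalize-after-transpose  ==  transpose-after-normalize   (last-axis norm)
# and that Reshape([0, 0, -1]) recovers the packed (B, S, D) query GQA expects.
import numpy as np

def simplified_layer_norm(x, scale, eps=1e-6):
    rms = np.sqrt(np.mean(np.square(x), axis=-1, keepdims=True) + eps)
    return (x / rms) * scale

B, S, H, Dh = 2, 5, 8, 16
scale = np.random.randn(Dh).astype(np.float32)
query_BSHDh = np.random.randn(B, S, H, Dh).astype(np.float32)

# What the matched subgraph computes: transpose to (B, H, S, D/H), then normalize.
pattern_path = simplified_layer_norm(query_BSHDh.transpose(0, 2, 1, 3), scale)

# What the rewrite emits: normalize in (B, S, H, D/H); transpose only for comparison.
rewrite_path = simplified_layer_norm(query_BSHDh, scale).transpose(0, 2, 1, 3)

assert np.allclose(pattern_path, rewrite_path)

# Reshape([0, 0, -1]) in ONNX keeps dims 0 and 1 and infers the rest, i.e. (B, S, H*Dh).
query_BSD = simplified_layer_norm(query_BSHDh, scale).reshape(B, S, -1)
assert query_BSD.shape == (B, S, H * Dh)
```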
