Skip to content

Commit 8a7de40

Browse files
authored
Add GQA fusion test cases (#2669)
Add GQA fusion test cases to cover extensions introduced recently to cover patterns seen in Qwen. --------- Signed-off-by: Ganesan Ramalingam <grama@microsoft.com>
1 parent 3846705 commit 8a7de40

File tree

1 file changed

+45
-12
lines changed

1 file changed

+45
-12
lines changed

onnxscript/rewriter/ort_fusions/gqa_test.py

Lines changed: 45 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
import onnx_ir as ir
1111
import onnx_ir.passes.common.shape_inference as shape_inference
1212
import onnxruntime as ort
13+
import parameterized
1314
import torch
1415

1516
import onnxscript
@@ -361,14 +362,26 @@ def test_fusion(self):
361362
assert_allclose(outputs3, source_model_outputs)
362363

363364

365+
@parameterized.parameterized_class(
366+
[
367+
{"with_past": True, "transpose_first": True},
368+
{"with_past": True, "transpose_first": False},
369+
{"with_past": False, "transpose_first": True},
370+
{"with_past": False, "transpose_first": False},
371+
]
372+
)
364373
class GemmaGQAFusionTest(unittest.TestCase):
374+
with_past = True
375+
transpose_first = True
376+
365377
def __init__(self, *args, **kwargs):
366378
super().__init__(*args, **kwargs)
379+
367380
# Config parameters
368381
self.batchsize = 1 # Note: GQA (cpu) seems to require batch-size 1?
369382
self.seqlen = 8
370383
self.kv_seqlen = self.seqlen
371-
self.past_seqlen = 16
384+
self.past_seqlen = 16 if self.with_past else 0
372385
self.head_size = 16
373386
self.num_heads = 20
374387
self.kv_num_heads = 10
@@ -425,6 +438,8 @@ def __init__(self, *args, **kwargs):
425438
}
426439

427440
def source_model_script(self):
441+
with_past = self.with_past
442+
transpose_first = self.transpose_first
428443
scale_factor = math.sqrt(math.sqrt(self.head_size))
429444
minval = torch.finfo(torch.float32).min
430445
minval_tp = onnx.helper.make_tensor("minval", onnx.TensorProto.FLOAT, [1], [minval])
@@ -458,16 +473,30 @@ def gqa(query, key, value, past_key, past_value, cos, sin, query_scale, key_scal
458473
# We convert them into BHSDh (i.e., BHSd) format. In this version, we have only
459474
# one sequence length (S) for all Q, K, and V (with no cache).
460475
query_BSHDh = op.Reshape(query, shape_BSHDh)
461-
query_BHSDh = op.Transpose(query_BSHDh, perm=[0, 2, 1, 3])
462-
query_BHSDh_normalized = op.SimplifiedLayerNormalization(
463-
query_BHSDh, query_scale, axis=-1, epsilon=1e-06, stash_type=1
464-
)
465-
466476
key_BSHkvDh = op.Reshape(key, shape_BSHkvDh)
467-
key_BHkvSDh = op.Transpose(key_BSHkvDh, perm=[0, 2, 1, 3])
468-
key_BHkvSDh_normalized = op.SimplifiedLayerNormalization(
469-
key_BHkvSDh, key_scale, axis=-1, epsilon=1e-06, stash_type=1
470-
)
477+
478+
if transpose_first:
479+
query_BHSDh = op.Transpose(query_BSHDh, perm=[0, 2, 1, 3])
480+
query_BHSDh_normalized = op.SimplifiedLayerNormalization(
481+
query_BHSDh, query_scale, axis=-1, epsilon=1e-06, stash_type=1
482+
)
483+
key_BHkvSDh = op.Transpose(key_BSHkvDh, perm=[0, 2, 1, 3])
484+
key_BHkvSDh_normalized = op.SimplifiedLayerNormalization(
485+
key_BHkvSDh, key_scale, axis=-1, epsilon=1e-06, stash_type=1
486+
)
487+
else:
488+
query_BSHDh_normalized = op.SimplifiedLayerNormalization(
489+
query_BSHDh, query_scale, axis=-1, epsilon=1e-06, stash_type=1
490+
)
491+
query_BHSDh_normalized = op.Transpose(
492+
query_BSHDh_normalized, perm=[0, 2, 1, 3]
493+
)
494+
key_BSHkvDh_normalized = op.SimplifiedLayerNormalization(
495+
key_BSHkvDh, key_scale, axis=-1, epsilon=1e-06, stash_type=1
496+
)
497+
key_BHkvSDh_normalized = op.Transpose(
498+
key_BSHkvDh_normalized, perm=[0, 2, 1, 3]
499+
)
471500

472501
value_BSHkvDh = op.Reshape(value, shape_BSHkvDh)
473502
value_BHkvSDh = op.Transpose(value_BSHkvDh, perm=[0, 2, 1, 3])
@@ -489,9 +518,13 @@ def gqa(query, key, value, past_key, past_value, cos, sin, query_scale, key_scal
489518
cos,
490519
sin,
491520
)
492-
key_seq_BHkvSkvDh = op.Concat(past_key, key_BHkvSDh_rope, axis=-2)
493521

494-
value_seq_BHkvSkvDh = op.Concat(past_value, value_BHkvSDh, axis=-2)
522+
if with_past:
523+
key_seq_BHkvSkvDh = op.Concat(past_key, key_BHkvSDh_rope, axis=-2)
524+
value_seq_BHkvSkvDh = op.Concat(past_value, value_BHkvSDh, axis=-2)
525+
else:
526+
key_seq_BHkvSkvDh = key_BHkvSDh_rope
527+
value_seq_BHkvSkvDh = value_BHkvSDh
495528

496529
# Now, expand from shared heads to all heads
497530
key_BHkv1SDh = op.Unsqueeze(key_seq_BHkvSkvDh, 2)

0 commit comments

Comments
 (0)