1010import onnx_ir as ir
1111import onnx_ir .passes .common .shape_inference as shape_inference
1212import onnxruntime as ort
13+ import parameterized
1314import torch
1415
1516import onnxscript
@@ -361,14 +362,26 @@ def test_fusion(self):
361362 assert_allclose (outputs3 , source_model_outputs )
362363
363364
365+ @parameterized .parameterized_class (
366+ [
367+ {"with_past" : True , "transpose_first" : True },
368+ {"with_past" : True , "transpose_first" : False },
369+ {"with_past" : False , "transpose_first" : True },
370+ {"with_past" : False , "transpose_first" : False },
371+ ]
372+ )
364373class GemmaGQAFusionTest (unittest .TestCase ):
374+ with_past = True
375+ transpose_first = True
376+
365377 def __init__ (self , * args , ** kwargs ):
366378 super ().__init__ (* args , ** kwargs )
379+
367380 # Config parameters
368381 self .batchsize = 1 # Note: GQA (cpu) seems to require batch-size 1?
369382 self .seqlen = 8
370383 self .kv_seqlen = self .seqlen
371- self .past_seqlen = 16
384+ self .past_seqlen = 16 if self . with_past else 0
372385 self .head_size = 16
373386 self .num_heads = 20
374387 self .kv_num_heads = 10
@@ -425,6 +438,8 @@ def __init__(self, *args, **kwargs):
425438 }
426439
427440 def source_model_script (self ):
441+ with_past = self .with_past
442+ transpose_first = self .transpose_first
428443 scale_factor = math .sqrt (math .sqrt (self .head_size ))
429444 minval = torch .finfo (torch .float32 ).min
430445 minval_tp = onnx .helper .make_tensor ("minval" , onnx .TensorProto .FLOAT , [1 ], [minval ])
@@ -458,16 +473,30 @@ def gqa(query, key, value, past_key, past_value, cos, sin, query_scale, key_scal
458473 # We convert them into BHSDh (i.e., BHSd) format. In this version, we have only
459474 # one sequence length (S) for all Q, K, and V (with no cache).
460475 query_BSHDh = op .Reshape (query , shape_BSHDh )
461- query_BHSDh = op .Transpose (query_BSHDh , perm = [0 , 2 , 1 , 3 ])
462- query_BHSDh_normalized = op .SimplifiedLayerNormalization (
463- query_BHSDh , query_scale , axis = - 1 , epsilon = 1e-06 , stash_type = 1
464- )
465-
466476 key_BSHkvDh = op .Reshape (key , shape_BSHkvDh )
467- key_BHkvSDh = op .Transpose (key_BSHkvDh , perm = [0 , 2 , 1 , 3 ])
468- key_BHkvSDh_normalized = op .SimplifiedLayerNormalization (
469- key_BHkvSDh , key_scale , axis = - 1 , epsilon = 1e-06 , stash_type = 1
470- )
477+
478+ if transpose_first :
479+ query_BHSDh = op .Transpose (query_BSHDh , perm = [0 , 2 , 1 , 3 ])
480+ query_BHSDh_normalized = op .SimplifiedLayerNormalization (
481+ query_BHSDh , query_scale , axis = - 1 , epsilon = 1e-06 , stash_type = 1
482+ )
483+ key_BHkvSDh = op .Transpose (key_BSHkvDh , perm = [0 , 2 , 1 , 3 ])
484+ key_BHkvSDh_normalized = op .SimplifiedLayerNormalization (
485+ key_BHkvSDh , key_scale , axis = - 1 , epsilon = 1e-06 , stash_type = 1
486+ )
487+ else :
488+ query_BSHDh_normalized = op .SimplifiedLayerNormalization (
489+ query_BSHDh , query_scale , axis = - 1 , epsilon = 1e-06 , stash_type = 1
490+ )
491+ query_BHSDh_normalized = op .Transpose (
492+ query_BSHDh_normalized , perm = [0 , 2 , 1 , 3 ]
493+ )
494+ key_BSHkvDh_normalized = op .SimplifiedLayerNormalization (
495+ key_BSHkvDh , key_scale , axis = - 1 , epsilon = 1e-06 , stash_type = 1
496+ )
497+ key_BHkvSDh_normalized = op .Transpose (
498+ key_BSHkvDh_normalized , perm = [0 , 2 , 1 , 3 ]
499+ )
471500
472501 value_BSHkvDh = op .Reshape (value , shape_BSHkvDh )
473502 value_BHkvSDh = op .Transpose (value_BSHkvDh , perm = [0 , 2 , 1 , 3 ])
@@ -489,9 +518,13 @@ def gqa(query, key, value, past_key, past_value, cos, sin, query_scale, key_scal
489518 cos ,
490519 sin ,
491520 )
492- key_seq_BHkvSkvDh = op .Concat (past_key , key_BHkvSDh_rope , axis = - 2 )
493521
494- value_seq_BHkvSkvDh = op .Concat (past_value , value_BHkvSDh , axis = - 2 )
522+ if with_past :
523+ key_seq_BHkvSkvDh = op .Concat (past_key , key_BHkvSDh_rope , axis = - 2 )
524+ value_seq_BHkvSkvDh = op .Concat (past_value , value_BHkvSDh , axis = - 2 )
525+ else :
526+ key_seq_BHkvSkvDh = key_BHkvSDh_rope
527+ value_seq_BHkvSkvDh = value_BHkvSDh
495528
496529 # Now, expand from shared heads to all heads
497530 key_BHkv1SDh = op .Unsqueeze (key_seq_BHkvSkvDh , 2 )
0 commit comments