@@ -361,6 +361,281 @@ def test_fusion(self):
         assert_allclose(outputs3, source_model_outputs)
 
 
+class GemmaGQAFusionTest(unittest.TestCase):
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        # Config parameters
+        self.batchsize = 1  # Note: GQA (cpu) seems to require batch-size 1?
+        self.seqlen = 8
+        self.kv_seqlen = self.seqlen
+        self.past_seqlen = 16
+        self.head_size = 16
+        self.num_heads = 20
+        self.kv_num_heads = 10
+
+        # Computed config parameters
+        self.hidden_size = self.head_size * self.num_heads
+        self.kv_hidden_size = self.head_size * self.kv_num_heads
+        assert (self.num_heads % self.kv_num_heads) == 0, (
+            "num_heads must be divisible by kv_num_heads"
+        )
+        self.num_groups = self.num_heads // self.kv_num_heads
+        self.total_seqlen = self.seqlen + self.past_seqlen
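+        # With num_heads=20 and kv_num_heads=10, num_groups=2: each KV head is shared
+        # by two query heads, which is the sharing GroupQueryAttention exploits.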
+
+        # Abbreviations
+        B = self.batchsize
+        S = self.seqlen
+        P = self.past_seqlen
+        D = self.hidden_size
+        Dkv = self.kv_hidden_size
+        Dh = self.head_size
+        Hkv = self.kv_num_heads
+        total_seqlen = S + P
+        max_seqlen = total_seqlen
+
+        # Input/output types have some dimensions as dynamic (even though the
+        # test case instance has specific values above).
+        self.input_types = (
+            FLOAT["B", "S", D],  # query
+            FLOAT["B", "S", Dkv],  # key
+            FLOAT["B", "S", Dkv],  # value
+            FLOAT["B", Hkv, "P", Dh],  # past_key
+            FLOAT["B", Hkv, "P", Dh],  # past_value
+            FLOAT["max_seqlen", Dh // 2],  # cos
+            FLOAT["max_seqlen", Dh // 2],  # sin
+            FLOAT["Dh"],  # query_scale
+            FLOAT["Dh"],  # key_scale
+        )
+        self.output_types = (
+            FLOAT["B", "S", D],  # attention
+            FLOAT["B", Hkv, "T", Dh],  # present_key
+            FLOAT["B", Hkv, "T", Dh],  # present_value
+        )
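+        # "T" in the present_key/present_value types is the total KV length P + S:
+        # the past cache plus the newly processed tokens.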
+
+        self.inputs = {
+            "query": np.random.rand(B, S, D).astype(np.float32),
+            "key": np.random.rand(B, S, Dkv).astype(np.float32),
+            "value": np.random.rand(B, S, Dkv).astype(np.float32),
+            "past_key": np.random.rand(B, Hkv, P, Dh).astype(np.float32),
+            "past_value": np.random.rand(B, Hkv, P, Dh).astype(np.float32),
+            "cos": np.random.rand(max_seqlen, Dh // 2).astype(np.float32),
+            "sin": np.random.rand(max_seqlen, Dh // 2).astype(np.float32),
+            "query_scale": np.random.rand(Dh).astype(np.float32),
+            "key_scale": np.random.rand(Dh).astype(np.float32),
+        }
+
+    def source_model_script(self):
+        scale_factor = math.sqrt(math.sqrt(self.head_size))
+        minval = torch.finfo(torch.float32).min
+        minval_tp = onnx.helper.make_tensor("minval", onnx.TensorProto.FLOAT, [1], [minval])
+        H = [self.num_heads]
+        Hkv = [self.kv_num_heads]
+        Dh = [self.head_size]
+        G = [self.num_groups]
+        minus_1 = [-1]  # inferred dimension in Reshape op
+        plus_1 = [1]
+
+        @script()
+        def gqa(query, key, value, past_key, past_value, cos, sin, query_scale, key_scale):
+            # Shapes used for Reshape ops. Note that there are a few different options for
+            # how shapes can be specified in an ONNX Reshape op (which supports the special
+            # values 0 and -1 to propagate an existing dimension and infer one dimension,
+            # respectively). The following shapes are based on what is observed in Phi
+            # models generated by the exporter.
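+            # For illustration: Reshape(X, [0, 0, -1, 16]) applied to an X of shape
+            # [B, S, 320] keeps the first two dimensions and infers 320 / 16 = 20 heads,
+            # yielding [B, S, 20, 16]. This script instead builds shapes explicitly with
+            # Concat, matching the exporter's output.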
+            B = op.Shape(query, start=0, end=1)
+            S = op.Shape(query, start=1, end=2)
+            past_seq_length = op.Shape(past_key, start=2, end=3)
+            total_seq_length = op.Add(past_seq_length, S)
+
+            shape_BSHDh = op.Concat(B, S, minus_1, Dh, axis=0)
+            shape_BSHkvDh = op.Concat(B, S, minus_1, Dh, axis=0)
+            shape_BSD = op.Concat(B, S, minus_1, axis=0)
+            shape_BHkvGSDh = op.Concat(B, Hkv, G, total_seq_length, Dh, axis=0)
+
+            shape_BHSDh = op.Concat(B, H, total_seq_length, Dh, axis=0)
+
+            # First, get Q, K, and V into the right shapes. Inputs are 3D tensors in the
+            # BSD format. D is different for Q and K/V (not reflected in the names,
+            # unfortunately). We convert them into BHSDh (i.e., BHSd) format. Q, K, and V
+            # all share the same new-token sequence length S; the past cache of length P
+            # is concatenated onto K and V further below.
+            query_BSHDh = op.Reshape(query, shape_BSHDh)
+            query_BHSDh = op.Transpose(query_BSHDh, perm=[0, 2, 1, 3])
+            query_BHSDh_normalized = op.SimplifiedLayerNormalization(
+                query_BHSDh, query_scale, axis=-1, epsilon=1e-06, stash_type=1
+            )
+
+            key_BSHkvDh = op.Reshape(key, shape_BSHkvDh)
+            key_BHkvSDh = op.Transpose(key_BSHkvDh, perm=[0, 2, 1, 3])
+            key_BHkvSDh_normalized = op.SimplifiedLayerNormalization(
+                key_BHkvSDh, key_scale, axis=-1, epsilon=1e-06, stash_type=1
+            )
+
+            value_BSHkvDh = op.Reshape(value, shape_BSHkvDh)
+            value_BHkvSDh = op.Transpose(value_BSHkvDh, perm=[0, 2, 1, 3])
+
+            # Concat past and do rotary embedding
+            position_ids_1d = op.Range(past_seq_length, total_seq_length, 1)
+            position_ids_q = op.Unsqueeze(position_ids_1d, [0])
+            position_ids_k = op.Unsqueeze(position_ids_1d, [0])
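+            # position_ids enumerate the absolute positions [P, P+S) of the new tokens;
+            # Q and K share them, so the same cos/sin rows are applied to both.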
+
+            query_BHSDh_rope = msft_op.RotaryEmbedding(
+                query_BHSDh_normalized,
+                position_ids_q,
+                cos,
+                sin,
+            )
+            key_BHkvSDh_rope = msft_op.RotaryEmbedding(
+                key_BHkvSDh_normalized,
+                position_ids_k,
+                cos,
+                sin,
+            )
+            key_seq_BHkvSkvDh = op.Concat(past_key, key_BHkvSDh_rope, axis=-2)
+
+            value_seq_BHkvSkvDh = op.Concat(past_value, value_BHkvSDh, axis=-2)
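+            # These concatenations build the present KV cache along the sequence axis:
+            # [B, Hkv, P, Dh] + [B, Hkv, S, Dh] -> [B, Hkv, P+S, Dh]. Both tensors are
+            # also returned as the present_key/present_value outputs below.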
+
+            # Now, expand from shared heads to all heads
+            key_BHkv1SDh = op.Unsqueeze(key_seq_BHkvSkvDh, 2)
+            key_BHkvGSDh = op.Expand(key_BHkv1SDh, shape_BHkvGSDh)
+            key_BHSDh = op.Reshape(key_BHkvGSDh, shape_BHSDh)
+
+            value_BHkv1SDh = op.Unsqueeze(value_seq_BHkvSkvDh, 2)
+            value_BHkvGSDh = op.Expand(value_BHkv1SDh, shape_BHkvGSDh)
+            value_BHSDh = op.Reshape(value_BHkvGSDh, shape_BHSDh)
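+            # Shape walkthrough of the Unsqueeze/Expand/Reshape idiom, in the test
+            # configuration (Hkv=10, G=2, H=20):
+            #   [B, 10, P+S, Dh] -> [B, 10, 1, P+S, Dh] -> [B, 10, 2, P+S, Dh]
+            #   -> [B, 20, P+S, Dh]
+            # i.e., each KV head is replicated G times so that the plain MatMul attention
+            # below sees one KV head per query head.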
+
+            # Generate causal mask:
+            # where every row looks like [0, 0, ..., /*diagonal=*/ 0, minval, minval, ...]
+            seq_len = op.Shape(query, end=2, start=1)
+            seq_len_0D = op.Squeeze(seq_len)
+
+            past_seq_len_0D = op.Squeeze(past_seq_length)
+
+            total_seq_len_0D = op.Add(past_seq_len_0D, seq_len_0D)
+            total_seq_len = op.Reshape(total_seq_len_0D, [-1])
+
+            # The Phi modeling code generates the following +1 as the target length,
+            # which seems unnecessary in this context, but we duplicate the same logic here.
+            total_seq_len_plus_1_0D = op.Add(total_seq_len_0D, 1)
+            total_seq_len_plus_1 = op.Reshape(total_seq_len_plus_1_0D, [-1])
+
+            current_range = op.Range(past_seq_len_0D, total_seq_len_0D, 1)
+            mask_shape = op.Concat(seq_len, total_seq_len_plus_1, axis=0)
+            min_val = op.Constant(value=minval_tp)
+            mask_all_min = op.Expand(min_val, mask_shape)
+            total_range_as_row = op.Range(0, total_seq_len_plus_1_0D, 1)
+            current_range_as_column = op.Reshape(current_range, [-1, 1])
+            boolean_mask = op.Greater(total_range_as_row, current_range_as_column)
+            float_0_1_mask = op.Cast(boolean_mask, to=1)
+            float_0_min_mask = op.Mul(mask_all_min, float_0_1_mask)
+            mask_4d = op.Unsqueeze(float_0_min_mask, [0, 1])
+            shape_B111 = op.Concat(B, plus_1, plus_1, plus_1, axis=0)
+            mask_B1ST_plus = op.Expand(mask_4d, shape_B111)
+
+            # Get rid of the extra +1 added above: total_seq_len is enough, no
+            # need for total_seq_len+1.
+            mask_B1ST = op.Slice(mask_B1ST_plus, [0], total_seq_len, [3], [1])
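+            # For example, with P=2 past tokens and S=2 new tokens, the sliced mask rows
+            # (one per new token, columns over all P+S positions) are:
+            #   [0, 0, 0, minval]  # position 2 attends to positions 0..2
+            #   [0, 0, 0, 0]       # position 3 attends to positions 0..3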
+
+            # Now, compute attention:
+            key_transposed = op.Transpose(key_BHSDh, perm=[0, 1, 3, 2])
+            divisor = op.Constant(value_float=scale_factor)
+            scaled_query = op.Div(query_BHSDh_rope, divisor)
+            scaled_key = op.Div(key_transposed, divisor)
+            attn_score = op.MatMul(scaled_query, scaled_key)
+            masked_attn_score = op.Add(attn_score, mask_B1ST)
+            attn_weight = op.Softmax(masked_attn_score, axis=-1)
+            attention_BHSDh = op.MatMul(attn_weight, value_BHSDh)
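+            # Note the split scaling: dividing both Q and K by head_size ** 0.25
+            # (scale_factor = sqrt(sqrt(head_size))) is equivalent to dividing the scores
+            # by sqrt(head_size), the standard SDPA scale; the SDPA fusion must match
+            # this pre-scaled form of the pattern.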
+
+            # Reshape back to BSD format
+            attention_BSHDh = op.Transpose(attention_BHSDh, perm=[0, 2, 1, 3])
+            attention_BSD = op.Reshape(attention_BSHDh, shape_BSD)
+
+            return attention_BSD, key_seq_BHkvSkvDh, value_seq_BHkvSkvDh
+
+        return gqa
+
+    def test_fusion(self):
+        """Test that GQA fusion is successful on source model and produces an equivalent model."""
+        inputs = self.inputs
+
+        source_model = self.source_model_script().to_model_proto(
+            input_types=self.input_types,
+            output_types=self.output_types,
+        )
+        session = ort.InferenceSession(
+            source_model.SerializeToString(), providers=("CPUExecutionProvider",)
+        )
+        source_model_outputs = session.run(None, inputs)
+
+        # Some shapes need to be present in the input model for fusion to be successful:
+        # (i) shape inference doesn't handle ORT contrib ops;
+        # (ii) TODO: investigate whether Reshape(..., ["B", "S", -1, Dh]) is handled
+        # precisely by shape inference.
+        query_BHSDh_rope_value_info = onnx.helper.make_tensor_value_info(
+            "query_BHSDh_rope",
+            onnx.TensorProto.FLOAT,
+            ["B", self.num_heads, self.seqlen, self.head_size],
+        )
+        key_BHkvSDh_rope_value_info = onnx.helper.make_tensor_value_info(
+            "key_BHkvSDh_rope",
+            onnx.TensorProto.FLOAT,
+            ["B", self.kv_num_heads, self.seqlen, self.head_size],
+        )
+        query_BSHDh_value_info = onnx.helper.make_tensor_value_info(
+            "query_BSHDh",
+            onnx.TensorProto.FLOAT,
+            ["B", self.seqlen, self.num_heads, self.head_size],
+        )
+        key_BHSDh_value_info = onnx.helper.make_tensor_value_info(
+            "key_BHSDh",
+            onnx.TensorProto.FLOAT,
+            ["B", self.num_heads, self.total_seqlen, self.head_size],
+        )
+        key_BSHkvDh_value_info = onnx.helper.make_tensor_value_info(
+            "key_BSHkvDh",
+            onnx.TensorProto.FLOAT,
+            ["B", self.seqlen, self.kv_num_heads, self.head_size],
+        )
+        key_transposed_value_info = onnx.helper.make_tensor_value_info(
+            "key_transposed",
+            onnx.TensorProto.FLOAT,
+            ["B", self.num_heads, self.head_size, self.total_seqlen],
+        )
+        value_BHSDh_value_info = onnx.helper.make_tensor_value_info(
+            "value_BHSDh",
+            onnx.TensorProto.FLOAT,
+            ["B", self.num_heads, self.total_seqlen, self.head_size],
+        )
+        source_model.graph.value_info.extend(
+            [
+                query_BHSDh_rope_value_info,
+                key_BHkvSDh_rope_value_info,
+                query_BSHDh_value_info,
+                key_BHSDh_value_info,
+                key_BSHkvDh_value_info,
+                key_transposed_value_info,
+                value_BHSDh_value_info,
+            ]
+        )
+
+        source_model_ir = ir.serde.from_proto(source_model)
+        inferred_model = shape_inference.infer_shapes(source_model_ir)
+        onnxscript.optimizer.optimize(inferred_model)
+
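+        # Fusion runs in two stages: fuse_sdpa first rewrites the scaled Q/K MatMul,
+        # mask Add, Softmax, and V MatMul into a single SDPA op; fuse_gqa then maps the
+        # surrounding pattern (rotary embedding, KV-cache concat, head expansion) onto
+        # ORT's GroupQueryAttention.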
+        count = fuse_sdpa(inferred_model, debug=True)
+        self.assertGreater(count, 0)
+
+        count = fuse_gqa(inferred_model, debug=True)
+        self.assertGreater(count, 0)
+
+        fused_model = ir.serde.to_proto(inferred_model)
+        session = ort.InferenceSession(
+            fused_model.SerializeToString(), providers=("CPUExecutionProvider",)
+        )
+        outputs3 = session.run(None, inputs)
+
+        self.assertEqual(len(outputs3), len(source_model_outputs))
+        assert_allclose(outputs3, source_model_outputs)
+
+
 class GQAFusionTest2(unittest.TestCase):
     @unittest.skip("Needs too much memory.")
     def test_phi4lm(self):