Commit 80f28c9
Add Gemma3 GQA fusion test case (#2642)
Add Gemma3 GQA fusion test case: variant with SimplifiedLayerNormalization applied to query and key.

Signed-off-by: Ganesan Ramalingam <grama@microsoft.com>
1 parent 55f5b82 commit 80f28c9
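
For context (not part of the commit): SimplifiedLayerNormalization is ONNX Runtime's RMS-style normalization contrib op, which the new test applies to the query and key. A minimal numpy sketch of the computation, assuming the usual RMSNorm formulation with the same axis=-1 and epsilon=1e-06 used in the test:

    import numpy as np

    def simplified_layer_norm(x: np.ndarray, scale: np.ndarray, epsilon: float = 1e-6) -> np.ndarray:
        # RMS-normalize over the last axis, then apply the per-channel scale.
        mean_square = np.mean(np.square(x), axis=-1, keepdims=True)
        return (x / np.sqrt(mean_square + epsilon)) * scale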

File tree: 1 file changed (+275 -0 lines)


onnxscript/rewriter/ort_fusions/gqa_test.py

Lines changed: 275 additions & 0 deletions
@@ -361,6 +361,281 @@ def test_fusion(self):
        assert_allclose(outputs3, source_model_outputs)


class GemmaGQAFusionTest(unittest.TestCase):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Config parameters
        self.batchsize = 1  # Note: GQA (cpu) seems to require batch-size 1?
        self.seqlen = 8
        self.kv_seqlen = self.seqlen
        self.past_seqlen = 16
        self.head_size = 16
        self.num_heads = 20
        self.kv_num_heads = 10

        # Computed config parameters
        self.hidden_size = self.head_size * self.num_heads
        self.kv_hidden_size = self.head_size * self.kv_num_heads
        assert (self.num_heads % self.kv_num_heads) == 0, (
            "num_heads must be divisible by kv_num_heads"
        )
        self.num_groups = self.num_heads // self.kv_num_heads
        self.total_seqlen = self.seqlen + self.past_seqlen

        # Abbreviations
        B = self.batchsize
        S = self.seqlen
        P = self.past_seqlen
        D = self.hidden_size
        Dkv = self.kv_hidden_size
        Dh = self.head_size
        Hkv = self.kv_num_heads
        total_seqlen = S + P
        max_seqlen = total_seqlen

        # Input/output types have some dimensions as dynamic (even though the
        # test case instance has specific values above).
        self.input_types = (
            FLOAT["B", "S", D],  # query
            FLOAT["B", "S", Dkv],  # key
            FLOAT["B", "S", Dkv],  # value
            FLOAT["B", Hkv, "P", Dh],  # past_key
            FLOAT["B", Hkv, "P", Dh],  # past_value
            FLOAT["max_seqlen", Dh // 2],  # cos
            FLOAT["max_seqlen", Dh // 2],  # sin
            FLOAT["Dh"],  # query_scale
            FLOAT["Dh"],  # key_scale
        )
        self.output_types = (
            FLOAT["B", "S", D],  # attention
            FLOAT["B", Hkv, "T", Dh],  # present_key
            FLOAT["B", Hkv, "T", Dh],  # present_value
        )

        self.inputs = {
            "query": np.random.rand(B, S, D).astype(np.float32),
            "key": np.random.rand(B, S, Dkv).astype(np.float32),
            "value": np.random.rand(B, S, Dkv).astype(np.float32),
            "past_key": np.random.rand(B, Hkv, P, Dh).astype(np.float32),
            "past_value": np.random.rand(B, Hkv, P, Dh).astype(np.float32),
            "cos": np.random.rand(max_seqlen, Dh // 2).astype(np.float32),
            "sin": np.random.rand(max_seqlen, Dh // 2).astype(np.float32),
            "query_scale": np.random.rand(Dh).astype(np.float32),
            "key_scale": np.random.rand(Dh).astype(np.float32),
        }

    def source_model_script(self):
        scale_factor = math.sqrt(math.sqrt(self.head_size))
        minval = torch.finfo(torch.float32).min
        minval_tp = onnx.helper.make_tensor("minval", onnx.TensorProto.FLOAT, [1], [minval])
        H = [self.num_heads]
        Hkv = [self.kv_num_heads]
        Dh = [self.head_size]
        G = [self.num_groups]
        minus_1 = [-1]  # inferred dimension in Reshape op
        plus_1 = [1]

        @script()
        def gqa(query, key, value, past_key, past_value, cos, sin, query_scale, key_scale):
            # Shapes used for Reshape ops. Note that we have a few different options on how shapes are
            # specified in an ONNX Reshape op (which supports special values 0 and -1 to propagate
            # existing dimension and one inferred dimension respectively). The following shapes are
            # based on what is observed in Phi models generated by the exporter.
            B = op.Shape(query, start=0, end=1)
            S = op.Shape(query, start=1, end=2)
            past_seq_length = op.Shape(past_key, start=2, end=3)
            total_seq_length = op.Add(past_seq_length, S)

            shape_BSHDh = op.Concat(B, S, minus_1, Dh, axis=0)
            shape_BSHkvDh = op.Concat(B, S, minus_1, Dh, axis=0)
            shape_BSD = op.Concat(B, S, minus_1, axis=0)
            shape_BHkvGSDh = op.Concat(B, Hkv, G, total_seq_length, Dh, axis=0)

            shape_BHSDh = op.Concat(B, H, total_seq_length, Dh, axis=0)

            # First, get Q, K, V into right shapes. Inputs are 3D tensors in the BSD format.
            # D is different for Q and K/V (not reflected in the names, unfortunately).
            # We convert them into BHSDh (i.e., BHSd) format. In this version, we have only
            # one sequence length (S) for all Q, K, and V (with no cache).
            query_BSHDh = op.Reshape(query, shape_BSHDh)
            query_BHSDh = op.Transpose(query_BSHDh, perm=[0, 2, 1, 3])
            query_BHSDh_normalized = op.SimplifiedLayerNormalization(
                query_BHSDh, query_scale, axis=-1, epsilon=1e-06, stash_type=1
            )

            key_BSHkvDh = op.Reshape(key, shape_BSHkvDh)
            key_BHkvSDh = op.Transpose(key_BSHkvDh, perm=[0, 2, 1, 3])
            key_BHkvSDh_normalized = op.SimplifiedLayerNormalization(
                key_BHkvSDh, key_scale, axis=-1, epsilon=1e-06, stash_type=1
            )

            value_BSHkvDh = op.Reshape(value, shape_BSHkvDh)
            value_BHkvSDh = op.Transpose(value_BSHkvDh, perm=[0, 2, 1, 3])

            # Concat past and do rotary embedding
            position_ids_1d = op.Range(past_seq_length, total_seq_length, 1)
            position_ids_q = op.Unsqueeze(position_ids_1d, [0])
            position_ids_k = op.Unsqueeze(position_ids_1d, [0])

            query_BHSDh_rope = msft_op.RotaryEmbedding(
                query_BHSDh_normalized,
                position_ids_q,
                cos,
                sin,
            )
            key_BHkvSDh_rope = msft_op.RotaryEmbedding(
                key_BHkvSDh_normalized,
                position_ids_k,
                cos,
                sin,
            )
            key_seq_BHkvSkvDh = op.Concat(past_key, key_BHkvSDh_rope, axis=-2)

            value_seq_BHkvSkvDh = op.Concat(past_value, value_BHkvSDh, axis=-2)

            # Now, expand from shared heads to all heads
            key_BHkv1SDh = op.Unsqueeze(key_seq_BHkvSkvDh, 2)
            key_BHkvGSDh = op.Expand(key_BHkv1SDh, shape_BHkvGSDh)
            key_BHSDh = op.Reshape(key_BHkvGSDh, shape_BHSDh)

            value_BHkv1SDh = op.Unsqueeze(value_seq_BHkvSkvDh, 2)
            value_BHkvGSDh = op.Expand(value_BHkv1SDh, shape_BHkvGSDh)
            value_BHSDh = op.Reshape(value_BHkvGSDh, shape_BHSDh)

            # Generate causal mask:
            # where every row looks like [0, 0, ..., /*diagonal=*/ 0, minval, minval, ...]
            seq_len = op.Shape(query, end=2, start=1)
            seq_len_0D = op.Squeeze(seq_len)

            past_seq_len_0D = op.Squeeze(past_seq_length)

            total_seq_len_0D = op.Add(past_seq_len_0D, seq_len_0D)
            total_seq_len = op.Reshape(total_seq_len_0D, [-1])

            # The Phi modeling code generates the following +1 as the target-length, which seems
            # unnecessary in this context, but the same logic is duplicated here.
            total_seq_len_plus_1_0D = op.Add(total_seq_len_0D, 1)
            total_seq_len_plus_1 = op.Reshape(total_seq_len_plus_1_0D, [-1])

            current_range = op.Range(past_seq_len_0D, total_seq_len_0D, 1)
            mask_shape = op.Concat(seq_len, total_seq_len_plus_1, axis=0)
            min_val = op.Constant(value=minval_tp)
            mask_all_min = op.Expand(min_val, mask_shape)
            total_range_as_row = op.Range(0, total_seq_len_plus_1_0D, 1)
            current_range_as_column = op.Reshape(current_range, [-1, 1])
            boolean_mask = op.Greater(total_range_as_row, current_range_as_column)
            float_0_1_mask = op.Cast(boolean_mask, to=1)
            float_0_min_mask = op.Mul(mask_all_min, float_0_1_mask)
            mask_4d = op.Unsqueeze(float_0_min_mask, [0, 1])
            shape_B111 = op.Concat(B, plus_1, plus_1, plus_1, axis=0)
            mask_B1ST_plus = op.Expand(mask_4d, shape_B111)

            # Get rid of the extra +1 added above: total_seq_len is enough, no
            # need for total_seq_len+1.
            mask_B1ST = op.Slice(mask_B1ST_plus, [0], total_seq_len, [3], [1])

            # Now, compute attention:
            key_transposed = op.Transpose(key_BHSDh, perm=[0, 1, 3, 2])
            divisor = op.Constant(value_float=scale_factor)
            scaled_query = op.Div(query_BHSDh_rope, divisor)
            scaled_key = op.Div(key_transposed, divisor)
            attn_score = op.MatMul(scaled_query, scaled_key)
            masked_attn_score = op.Add(attn_score, mask_B1ST)
            attn_weight = op.Softmax(masked_attn_score, axis=-1)
            attention_BHSDh = op.MatMul(attn_weight, value_BHSDh)

            # Reshape back to BSD format
            attention_BSHDh = op.Transpose(attention_BHSDh, perm=[0, 2, 1, 3])
            attention_BSD = op.Reshape(attention_BSHDh, shape_BSD)

            return attention_BSD, key_seq_BHkvSkvDh, value_seq_BHkvSkvDh

        return gqa

    def test_fusion(self):
        """Test that GQA fusion is successful on source model and produces an equivalent model."""
        inputs = self.inputs

        source_model = self.source_model_script().to_model_proto(
            input_types=self.input_types,
            output_types=self.output_types,
        )
        session = ort.InferenceSession(
            source_model.SerializeToString(), providers=("CPUExecutionProvider",)
        )
        source_model_outputs = session.run(None, inputs)

        # Some shapes need to be present in the input model for fusion to be successful.
        # (i) Shape inference doesn't handle ORT contrib ops.
        # (ii) TODO: investigate whether Reshape(..., ["B", "S", -1, Dh]) is handled precisely
        # by shape inference.
        query_BHSDh_rope_value_info = onnx.helper.make_tensor_value_info(
            "query_BHSDh_rope",
            onnx.TensorProto.FLOAT,
            ["B", self.num_heads, self.seqlen, self.head_size],
        )
        key_BHkvSDh_rope_value_info = onnx.helper.make_tensor_value_info(
            "key_BHkvSDh_rope",
            onnx.TensorProto.FLOAT,
            ["B", self.kv_num_heads, self.seqlen, self.head_size],
        )
        query_BSHDh_value_info = onnx.helper.make_tensor_value_info(
            "query_BSHDh",
            onnx.TensorProto.FLOAT,
            ["B", self.seqlen, self.num_heads, self.head_size],
        )
        key_BHSDh_value_info = onnx.helper.make_tensor_value_info(
            "key_BHSDh",
            onnx.TensorProto.FLOAT,
            ["B", self.num_heads, self.total_seqlen, self.head_size],
        )
        key_BSHkvDh_value_info = onnx.helper.make_tensor_value_info(
            "key_BSHkvDh",
            onnx.TensorProto.FLOAT,
            ["B", self.seqlen, self.kv_num_heads, self.head_size],
        )
        key_transposed_value_info = onnx.helper.make_tensor_value_info(
            "key_transposed",
            onnx.TensorProto.FLOAT,
            ["B", self.num_heads, self.head_size, self.total_seqlen],
        )
        value_BHSDh_value_info = onnx.helper.make_tensor_value_info(
            "value_BHSDh",
            onnx.TensorProto.FLOAT,
            ["B", self.num_heads, self.total_seqlen, self.head_size],
        )
        source_model.graph.value_info.extend(
            [
                query_BHSDh_rope_value_info,
                key_BHkvSDh_rope_value_info,
                query_BSHDh_value_info,
                key_BHSDh_value_info,
                key_BSHkvDh_value_info,
                key_transposed_value_info,
                value_BHSDh_value_info,
            ]
        )

        source_model_ir = ir.serde.from_proto(source_model)
        inferred_model = shape_inference.infer_shapes(source_model_ir)
        onnxscript.optimizer.optimize(inferred_model)

        count = fuse_sdpa(inferred_model, debug=True)
        self.assertGreater(count, 0)

        count = fuse_gqa(inferred_model, debug=True)
        self.assertGreater(count, 0)

        fused_model = ir.serde.to_proto(inferred_model)
        session = ort.InferenceSession(
            fused_model.SerializeToString(), providers=("CPUExecutionProvider",)
        )
        outputs3 = session.run(None, inputs)

        self.assertEqual(len(outputs3), len(source_model_outputs))
        assert_allclose(outputs3, source_model_outputs)


class GQAFusionTest2(unittest.TestCase):
    @unittest.skip("Needs too much memory.")
    def test_phi4lm(self):
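
Aside (not part of the commit): the "expand from shared heads to all heads" steps in the gqa script above (Unsqueeze, then Expand, then Reshape) are the defining GQA trick. A minimal numpy sketch of the equivalent computation, reusing the B/Hkv/G/Dh naming from the test:

    import numpy as np

    def expand_kv_heads(kv: np.ndarray, num_groups: int) -> np.ndarray:
        # (B, Hkv, T, Dh) -> (B, Hkv * G, T, Dh): repeat each shared KV head for its group of
        # query heads, mirroring the Unsqueeze -> Expand -> Reshape sequence in the script.
        B, Hkv, T, Dh = kv.shape
        kv_BHkv1TDh = kv[:, :, None, :, :]                                       # Unsqueeze (axis 2)
        kv_BHkvGTDh = np.broadcast_to(kv_BHkv1TDh, (B, Hkv, num_groups, T, Dh))  # Expand
        return kv_BHkvGTDh.reshape(B, Hkv * num_groups, T, Dh)                   # Reshape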

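Aside (not part of the commit): the causal-mask construction in the script produces an additive mask of shape (B, 1, S, total_seq_len), with 0 at visible positions and the float32 minimum at future positions. A small numpy sketch of the intended result, skipping the temporary extra column that the script slices off:

    import numpy as np

    def causal_mask(batch: int, seqlen: int, past_seqlen: int) -> np.ndarray:
        minval = np.finfo(np.float32).min
        # Absolute positions of the new query tokens (the past tokens come first).
        query_pos = np.arange(past_seqlen, past_seqlen + seqlen).reshape(-1, 1)  # (S, 1)
        key_pos = np.arange(past_seqlen + seqlen)                                # (T,)
        # 0 where the key is visible to the query, minval where it lies in the future.
        mask = np.where(key_pos > query_pos, minval, 0.0).astype(np.float32)     # (S, T)
        return np.broadcast_to(mask[None, None], (batch, 1, seqlen, past_seqlen + seqlen))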