@@ -166,11 +166,23 @@ def pattern(
166166 # Transpose from (B, S, H, D/H) to (B, H, S, D/H)
167167 query_BHSDh = op .Transpose (query_BSHDh , perm = [0 , 2 , 1 , 3 ])
168168
169+ # Gemma variant uses normalization of query/key before rotary embedding:
170+ query_BHSDh_normalized = op .SimplifiedLayerNormalization (
171+ query_BHSDh , pattern .ANY_VALUE , axis = - 1 , _outputs = ["query_BHSDh_normalized" ]
172+ )
173+ query_BHSDh = pattern .OrValue ([query_BHSDh , query_BHSDh_normalized ])
174+
169175 # Reshape key from (B, S, Dkv) to (B, S, Hkv, D/H)
170176 key_BSHkvDh = op .Reshape (key_BSDkv , pattern .ANY_VALUE , _outputs = ["key_BSHkvDh" ])
171177 # Transpose from (B, S, Hkv, D/H) to (B, Hkv, S, D/H)
172178 key_BHkvSDh = op .Transpose (key_BSHkvDh , perm = [0 , 2 , 1 , 3 ])
173179
180+ # Gemma variant uses normalization of query/key before rotary embedding:
181+ key_BHkvSDh_normalized = op .SimplifiedLayerNormalization (
182+ key_BHkvSDh , pattern .ANY_VALUE , axis = - 1 , _outputs = ["key_BHkvSDh_normalized" ]
183+ )
184+ key_BHkvSDh = pattern .OrValue ([key_BHkvSDh , key_BHkvSDh_normalized ])
185+
174186 # Reshape value from (B, S, Dkv) to (B, S, Hkv, D/H)
175187 value_BSHkvDh = op .Reshape (value_BSDkv , pattern .ANY_VALUE , _outputs = ["value_BSHkvDh" ])
176188 # Transpose from (B, S, Hkv, D/H) to (B, Hkv, S, D/H)
@@ -316,6 +328,10 @@ def rewrite(
316328 cos ,
317329 sin ,
318330 mask ,
331+ query_BSHDh ,
332+ key_BSHkvDh ,
333+ query_BHSDh_normalized = None ,
334+ key_BHkvSDh_normalized = None ,
319335 ** _ ,
320336 ):
321337 # Note that the following optimization is specific to current ORT GenAI attention-mask
@@ -335,6 +351,29 @@ def rewrite(
335351 seqlens_k = op .Cast (seqlens_k_int64 , to = ir .DataType .INT32 )
336352 max_seq_length = op .ReduceMax (seqlens_k , zero_int64_1d , keepdims = 0 )
337353 total_seq_length_int32 = op .Add (max_seq_length , one_int32_0d )
354+
355+ if query_BHSDh_normalized is not None :
356+ # We apply normalization without the transpose, which is fused into GQA
357+ norm_node = query_BHSDh_normalized .producer ()
358+ norm_attrs = norm_node .attributes
359+ norm_scale = norm_node .inputs [1 ]
360+ query_BSHDh_normalized = op .SimplifiedLayerNormalization (
361+ query_BSHDh , norm_scale , ** norm_attrs
362+ )
363+ reshape_BSHDh_to_BSD = op .Constant (value_ints = [0 , 0 , - 1 ])
364+ query_BSD = op .Reshape (query_BSHDh_normalized , reshape_BSHDh_to_BSD )
365+
366+ if key_BHkvSDh_normalized is not None :
367+ # We apply normalization without the transpose, which is fused into GQA
368+ norm_node = key_BHkvSDh_normalized .producer ()
369+ norm_attrs = norm_node .attributes
370+ norm_scale = norm_node .inputs [1 ]
371+ key_BSHkvDh_normalized = op .SimplifiedLayerNormalization (
372+ key_BSHkvDh , norm_scale , ** norm_attrs
373+ )
374+ reshape_BSHkvDh_to_BSDkv = op .Constant (value_ints = [0 , 0 , - 1 ])
375+ key_BSDkv = op .Reshape (key_BSHkvDh_normalized , reshape_BSHkvDh_to_BSDkv )
376+
338377 return op .GroupQueryAttention (
339378 query_BSD ,
340379 key_BSDkv ,
0 commit comments