@@ -70,7 +70,9 @@ __global__ void fusedQKNormRopeKernel(
     float factor, // factor in rope_scaling in config.json. When it is not 1.0, it means the model is using yarn.
     float low, // threshold for high frequency
     float high, // threshold for low frequency
-    float attention_factor // attention_factor applied on cos and sin
+    float attention_factor, // attention_factor applied on cos and sin
+    // end of parameters for yarn
+    bool is_qk_norm // Whether to apply QK norm
 )
 {
     int const warpsPerBlock = blockDim.x / 32;
@@ -136,20 +138,22 @@ __global__ void fusedQKNormRopeKernel(
         }
     }

-    // Reduce sum across warp using the utility function
-    sumOfSquares = tensorrt_llm::common::warpReduceSum(sumOfSquares);
+    if (is_qk_norm)
+    {
+        // Reduce sum across warp using the utility function
+        sumOfSquares = tensorrt_llm::common::warpReduceSum(sumOfSquares);

-    // Compute RMS normalization factor
-    float rms_rcp = rsqrtf(sumOfSquares / static_cast<float>(head_dim) + eps);
+        // Compute RMS normalization factor
+        float rms_rcp = rsqrtf(sumOfSquares / static_cast<float>(head_dim) + eps);

-    // Normalize elements
-    for (int i = 0; i < numElemsPerThread; i++)
-    {
-        int dim = laneId * numElemsPerThread + i;
-        float weight = isQ ? __bfloat162float(q_weight[dim]) : __bfloat162float(k_weight[dim]);
-        elements[i] *= rms_rcp * weight;
+        // Normalize elements
+        for (int i = 0; i < numElemsPerThread; i++)
+        {
+            int dim = laneId * numElemsPerThread + i;
+            float weight = isQ ? __bfloat162float(q_weight[dim]) : __bfloat162float(k_weight[dim]);
+            elements[i] *= rms_rcp * weight;
+        }
     }
-
     // Apply RoPE to normalized elements
     float elements2[numElemsPerThread]; // Additional buffer required for RoPE.
     float cos_vals[numElemsPerThread];
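For reference, the guarded block above is a per-head RMSNorm applied before RoPE: each head's head_dim elements are scaled by rsqrt(mean(x^2) + eps) and a learned per-dimension weight, and the whole step is skipped when is_qk_norm is false. Below is a minimal host-side sketch of that math, not part of the patch; the helper name rmsNormHeadReference is hypothetical and it assumes one head's elements and weights are already converted to float.

#include <cmath>

// Reference-only sketch mirroring the is_qk_norm branch for a single head of
// `head_dim` float elements with a per-dimension weight.
void rmsNormHeadReference(float* x, float const* weight, int head_dim, float eps)
{
    float sumOfSquares = 0.f;
    for (int i = 0; i < head_dim; ++i)
    {
        sumOfSquares += x[i] * x[i];
    }
    // Same normalization factor as the kernel: rsqrt(mean of squares + eps)
    float rms_rcp = 1.f / std::sqrt(sumOfSquares / static_cast<float>(head_dim) + eps);
    for (int i = 0; i < head_dim; ++i)
    {
        x[i] *= rms_rcp * weight[i]; // elements pass through unchanged when is_qk_norm is false
    }
}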
@@ -276,7 +280,7 @@ __global__ void fusedQKNormRopeKernel(
 void launchFusedQKNormRope(void* qkv, int const num_tokens, int const num_heads_q, int const num_heads_k,
     int const num_heads_v, int const head_dim, float const eps, void const* q_weight, void const* k_weight,
     float const base, bool const interleave, int const* position_ids, float factor, float low, float high,
-    float attention_factor, cudaStream_t stream)
+    float attention_factor, cudaStream_t stream, bool is_qk_norm)
 {
     if (factor == 1.0f)
     {
@@ -301,23 +305,23 @@ void launchFusedQKNormRope(void* qkv, int const num_tokens, int const num_heads_
             fusedQKNormRopeKernel<64, INTERLEAVE><<<gridDim, blockDim, 0, stream>>>(
                 reinterpret_cast<__nv_bfloat16*>(qkv), num_heads_q, num_heads_k, num_heads_v, eps,
                 reinterpret_cast<__nv_bfloat16 const*>(q_weight), reinterpret_cast<__nv_bfloat16 const*>(k_weight),
-                base, position_ids, num_tokens, factor, low, high, attention_factor);
+                base, position_ids, num_tokens, factor, low, high, attention_factor, is_qk_norm);
         });
         break;
     case 128:
         DISPATCH_INTERLEAVE(interleave, INTERLEAVE, {
             fusedQKNormRopeKernel<128, INTERLEAVE><<<gridDim, blockDim, 0, stream>>>(
                 reinterpret_cast<__nv_bfloat16*>(qkv), num_heads_q, num_heads_k, num_heads_v, eps,
                 reinterpret_cast<__nv_bfloat16 const*>(q_weight), reinterpret_cast<__nv_bfloat16 const*>(k_weight),
-                base, position_ids, num_tokens, factor, low, high, attention_factor);
+                base, position_ids, num_tokens, factor, low, high, attention_factor, is_qk_norm);
         });
         break;
     case 256:
         DISPATCH_INTERLEAVE(interleave, INTERLEAVE, {
             fusedQKNormRopeKernel<256, INTERLEAVE><<<gridDim, blockDim, 0, stream>>>(
                 reinterpret_cast<__nv_bfloat16*>(qkv), num_heads_q, num_heads_k, num_heads_v, eps,
                 reinterpret_cast<__nv_bfloat16 const*>(q_weight), reinterpret_cast<__nv_bfloat16 const*>(k_weight),
-                base, position_ids, num_tokens, factor, low, high, attention_factor);
+                base, position_ids, num_tokens, factor, low, high, attention_factor, is_qk_norm);
         });
         break;
     default: TLLM_THROW("Unsupported head dimension for fusedQKNormRope: %d", head_dim);
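Call sites now pass the flag as the last argument, after the stream. Below is an illustrative sketch of calling the updated launcher; it is not from the patch, and all head counts, dimensions, and buffers are placeholders. Only the argument order is taken from the signature above.

// Hypothetical call-site sketch; values and buffers are placeholders.
void exampleLaunch(void* qkv, void const* q_weight, void const* k_weight,
    int const* position_ids, int num_tokens, cudaStream_t stream)
{
    launchFusedQKNormRope(qkv, num_tokens,
        /*num_heads_q=*/32, /*num_heads_k=*/8, /*num_heads_v=*/8,
        /*head_dim=*/128, /*eps=*/1e-6f, q_weight, k_weight,
        /*base=*/10000.f, /*interleave=*/false, position_ids,
        /*factor=*/1.0f, /*low=*/0.f, /*high=*/0.f, /*attention_factor=*/1.0f,
        stream, /*is_qk_norm=*/true); // set false to skip the QK RMSNorm and apply RoPE only
}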