
Commit 1a93583

[None][feat] Support Yarn on QwQ-32B model (#9059)
Signed-off-by: bhsueh <11360707+byshiue@users.noreply.github.com>
Signed-off-by: Jiang Shao <91270701+StudyingShao@users.noreply.github.com>
Co-authored-by: NVJiangShao <91270701+StudyingShao@users.noreply.github.com>
1 parent 1ce483c commit 1a93583

6 files changed (+97, -28 lines)

cpp/tensorrt_llm/kernels/fusedQKNormRopeKernel.cu

Lines changed: 20 additions & 16 deletions
@@ -70,7 +70,9 @@ __global__ void fusedQKNormRopeKernel(
     float factor, // factor in rope_scaling in config.json. When it is not 1.0, it means the model is using yarn.
     float low, // threshold for high frequency
     float high, // threshold for low frequency
-    float attention_factor // attention_factor applied on cos and sin
+    float attention_factor, // attention_factor applied on cos and sin
+    // stop of parameters for yarn
+    bool is_qk_norm // Whether to apply QK norm
 )
 {
     int const warpsPerBlock = blockDim.x / 32;
@@ -136,20 +138,22 @@ __global__ void fusedQKNormRopeKernel(
         }
     }

-    // Reduce sum across warp using the utility function
-    sumOfSquares = tensorrt_llm::common::warpReduceSum(sumOfSquares);
+    if (is_qk_norm)
+    {
+        // Reduce sum across warp using the utility function
+        sumOfSquares = tensorrt_llm::common::warpReduceSum(sumOfSquares);

-    // Compute RMS normalization factor
-    float rms_rcp = rsqrtf(sumOfSquares / static_cast<float>(head_dim) + eps);
+        // Compute RMS normalization factor
+        float rms_rcp = rsqrtf(sumOfSquares / static_cast<float>(head_dim) + eps);

-    // Normalize elements
-    for (int i = 0; i < numElemsPerThread; i++)
-    {
-        int dim = laneId * numElemsPerThread + i;
-        float weight = isQ ? __bfloat162float(q_weight[dim]) : __bfloat162float(k_weight[dim]);
-        elements[i] *= rms_rcp * weight;
+        // Normalize elements
+        for (int i = 0; i < numElemsPerThread; i++)
+        {
+            int dim = laneId * numElemsPerThread + i;
+            float weight = isQ ? __bfloat162float(q_weight[dim]) : __bfloat162float(k_weight[dim]);
+            elements[i] *= rms_rcp * weight;
+        }
     }
-
     // Apply RoPE to normalized elements
     float elements2[numElemsPerThread]; // Additional buffer required for RoPE.
     float cos_vals[numElemsPerThread];
@@ -276,7 +280,7 @@ __global__ void fusedQKNormRopeKernel(
 void launchFusedQKNormRope(void* qkv, int const num_tokens, int const num_heads_q, int const num_heads_k,
     int const num_heads_v, int const head_dim, float const eps, void const* q_weight, void const* k_weight,
     float const base, bool const interleave, int const* position_ids, float factor, float low, float high,
-    float attention_factor, cudaStream_t stream)
+    float attention_factor, cudaStream_t stream, bool is_qk_norm)
 {
     if (factor == 1.0f)
     {
@@ -301,23 +305,23 @@ void launchFusedQKNormRope(void* qkv, int const num_tokens, int const num_heads_
             fusedQKNormRopeKernel<64, INTERLEAVE><<<gridDim, blockDim, 0, stream>>>(
                 reinterpret_cast<__nv_bfloat16*>(qkv), num_heads_q, num_heads_k, num_heads_v, eps,
                 reinterpret_cast<__nv_bfloat16 const*>(q_weight), reinterpret_cast<__nv_bfloat16 const*>(k_weight),
-                base, position_ids, num_tokens, factor, low, high, attention_factor);
+                base, position_ids, num_tokens, factor, low, high, attention_factor, is_qk_norm);
         });
         break;
     case 128:
         DISPATCH_INTERLEAVE(interleave, INTERLEAVE, {
             fusedQKNormRopeKernel<128, INTERLEAVE><<<gridDim, blockDim, 0, stream>>>(
                 reinterpret_cast<__nv_bfloat16*>(qkv), num_heads_q, num_heads_k, num_heads_v, eps,
                 reinterpret_cast<__nv_bfloat16 const*>(q_weight), reinterpret_cast<__nv_bfloat16 const*>(k_weight),
-                base, position_ids, num_tokens, factor, low, high, attention_factor);
+                base, position_ids, num_tokens, factor, low, high, attention_factor, is_qk_norm);
         });
         break;
     case 256:
         DISPATCH_INTERLEAVE(interleave, INTERLEAVE, {
             fusedQKNormRopeKernel<256, INTERLEAVE><<<gridDim, blockDim, 0, stream>>>(
                 reinterpret_cast<__nv_bfloat16*>(qkv), num_heads_q, num_heads_k, num_heads_v, eps,
                 reinterpret_cast<__nv_bfloat16 const*>(q_weight), reinterpret_cast<__nv_bfloat16 const*>(k_weight),
-                base, position_ids, num_tokens, factor, low, high, attention_factor);
+                base, position_ids, num_tokens, factor, low, high, attention_factor, is_qk_norm);
         });
         break;
     default: TLLM_THROW("Unsupported head dimension for fusedQKNormRope: %d", head_dim);
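For context on the new kernel parameters: factor, low, high, and attention_factor describe YaRN's "NTK-by-parts" frequency blending, and is_qk_norm gates the RMSNorm step so the same kernel can serve models without QK norm (such as QwQ-32B). The sketch below is a host-side reference of how such YaRN values are commonly derived, following the Hugging Face / YaRN formulation; the beta_fast/beta_slow defaults and the clamping details are assumptions, and in TensorRT-LLM the values actually passed to the kernel come from RopeParams, not from this code.

import math

import torch


def yarn_rope_params(head_dim: int,
                     base: float,
                     factor: float,
                     orig_max_pos: int,
                     beta_fast: float = 32.0,
                     beta_slow: float = 1.0):
    """Reference YaRN ("NTK-by-parts") frequency blending for RoPE."""

    def correction_dim(num_rotations: float) -> float:
        # Dimension index whose rotary period completes `num_rotations`
        # turns over the original context window.
        return (head_dim * math.log(orig_max_pos /
                                    (num_rotations * 2 * math.pi))) / (
                                        2 * math.log(base))

    # Per-dimension thresholds: dims below `low` keep their original (high)
    # frequency, dims above `high` are fully interpolated.
    low = max(math.floor(correction_dim(beta_fast)), 0)
    high = min(math.ceil(correction_dim(beta_slow)), head_dim // 2 - 1)

    inv_freq = 1.0 / base**(
        torch.arange(0, head_dim, 2, dtype=torch.float32) / head_dim)
    ramp = ((torch.arange(head_dim // 2, dtype=torch.float32) - low) /
            max(high - low, 1e-3)).clamp(0, 1)
    keep_original = 1.0 - ramp  # 1 -> keep original frequency, 0 -> interpolate
    blended_inv_freq = (inv_freq / factor) * (1.0 - keep_original) \
        + inv_freq * keep_original

    # Scale applied to cos/sin ("mscale" in the YaRN paper).
    attention_factor = 0.1 * math.log(factor) + 1.0
    return blended_inv_freq, float(low), float(high), attention_factor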

cpp/tensorrt_llm/kernels/fusedQKNormRopeKernel.h

Lines changed: 2 additions & 1 deletion
@@ -42,7 +42,8 @@ void launchFusedQKNormRope(
     float low, // threshold for high frequency
     float high, // threshold for low frequency
     float attention_factor, // attention_factor applied on cos and sin
-    cudaStream_t stream); // CUDA stream
+    cudaStream_t stream, // CUDA stream
+    bool is_qk_norm);

 } // namespace kernels
 } // namespace tensorrt_llm

cpp/tensorrt_llm/thop/fusedQKNormRopeOp.cpp

Lines changed: 4 additions & 3 deletions
@@ -42,7 +42,8 @@ void fused_qk_norm_rope(
     double factor, // factor in rope_scaling in config.json. When it is not 1.0, it means the model is using yarn.
     double low, // threshold for high frequency
     double high, // threshold for low frequency
-    double attention_factor // attention_factor applied on cos and sin
+    double attention_factor, // attention_factor applied on cos and sin
+    bool is_qk_norm // Whether to apply QK norm
 )
 {
     // Input validation
@@ -74,7 +75,7 @@ void fused_qk_norm_rope(
         static_cast<float>(base),
         !is_neox, // interleave
         reinterpret_cast<int const*>(position_ids.data_ptr()), static_cast<float>(factor), static_cast<float>(low),
-        static_cast<float>(high), static_cast<float>(attention_factor), stream);
+        static_cast<float>(high), static_cast<float>(attention_factor), stream, is_qk_norm);
 }

 // Register the PyTorch operators
@@ -83,7 +84,7 @@ TORCH_LIBRARY_FRAGMENT(trtllm, m)
     m.def(
         "fused_qk_norm_rope(Tensor(a!) qkv, int num_heads_q, int num_heads_k, int num_heads_v, int head_dim, float "
         "eps, Tensor q_weight, Tensor k_weight, float base, bool is_neox, Tensor position_ids, float factor, float "
-        "low, float high, float attention_factor) -> ()");
+        "low, float high, float attention_factor, bool is_qk_norm) -> ()");
 }

 // Register the CUDA implementation
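With the updated schema, every caller has to pass the extra trailing boolean. Below is a minimal call sketch mirroring the updated call sites in qk_norm_attention.py and the unit test; the packed QKV layout, dtypes, parameter values, and the import that registers the op are assumptions, not taken from this diff.

import torch

import tensorrt_llm  # noqa: F401 -- assumed to register the trtllm torch ops

num_heads_q, num_heads_k, num_heads_v, head_dim = 8, 2, 2, 128
num_tokens = 4

# Packed QKV layout and dtypes here are illustrative assumptions.
qkv = torch.randn(num_tokens,
                  (num_heads_q + num_heads_k + num_heads_v) * head_dim,
                  dtype=torch.bfloat16,
                  device="cuda")
q_weight = torch.ones(head_dim, dtype=torch.bfloat16, device="cuda")
k_weight = torch.ones(head_dim, dtype=torch.bfloat16, device="cuda")
position_ids = torch.arange(num_tokens, dtype=torch.int32, device="cuda")

# In-place op; the trailing bool is the new is_qk_norm flag.
torch.ops.trtllm.fused_qk_norm_rope(
    qkv, num_heads_q, num_heads_k, num_heads_v, head_dim,
    1e-6,            # eps
    q_weight, k_weight,
    1e6,             # rope base (theta)
    True,            # is_neox
    position_ids,
    1.0, 0.0, 0.0,   # factor, low, high (factor == 1.0 -> plain RoPE, no YaRN)
    1.0,             # attention_factor
    True)            # is_qk_norm: RMS-normalize q/k heads before RoPE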

tensorrt_llm/_torch/models/modeling_qwen.py

Lines changed: 63 additions & 4 deletions
@@ -14,6 +14,7 @@
 from ..modules.embedding import Embedding
 from ..modules.gated_mlp import GatedMLP
 from ..modules.linear import Linear, TensorParallelMode
+from ..modules.qk_norm_attention import QKNormRoPEAttention
 from ..modules.rms_norm import RMSNorm
 from .modeling_utils import (DecoderModel, DecoderModelForCausalLM,
                              register_auto_model)
@@ -53,6 +54,56 @@ def __init__(
         )


+# TODO this is a workaround to support yarn on Qwen2.
+# We need refactor the codes to merge QwenYarnAttention and QwenAttention.
+class QwenYarnAttention(QKNormRoPEAttention):
+
+    def __init__(
+        self,
+        model_config: ModelConfig[Qwen2Config],
+        layer_idx: Optional[int] = None,
+        fuse_qk_norm_rope: bool = True,
+    ):
+        config = model_config.pretrained_config
+
+        if getattr(config, "rope_scaling", None) is not None:
+            if "type" in config.rope_scaling:
+                pos_type = config.rope_scaling["type"]
+            elif "rope_type" in config.rope_scaling:
+                pos_type = config.rope_scaling["rope_type"]
+            else:
+                raise ValueError(
+                    "rope_scaling must have type or rope_type field")
+            pos_embd_params = PositionalEmbeddingParams(
+                type=PositionEmbeddingType.from_string(pos_type),
+                rope=RopeParams.from_config(config),
+            )
+        else:
+            pos_embd_params = PositionalEmbeddingParams(
+                type=PositionEmbeddingType.rope_gpt_neox,
+                rope=RopeParams.from_config(config),
+            )
+
+        # Qwen3 has accuracy issues with deep_gemm (see: https://nvbugspro.nvidia.com/bug/5461712
+        # and https://nvbugspro.nvidia.com/bug/5505402)
+        disable_deep_gemm = True
+        super().__init__(
+            hidden_size=config.hidden_size,
+            num_attention_heads=config.num_attention_heads,
+            num_key_value_heads=config.num_key_value_heads,
+            max_position_embeddings=config.max_position_embeddings,
+            bias=True,
+            pos_embd_params=pos_embd_params,
+            fuse_qk_norm_rope=fuse_qk_norm_rope,
+            layer_idx=layer_idx,
+            dtype=config.torch_dtype,
+            dense_bias=False,
+            config=model_config,
+            disable_deep_gemm=disable_deep_gemm,
+            is_qk_norm=False,
+        )
+
+
 class QwenDecoderLayer(DecoderLayer):

     def __init__(
@@ -63,10 +114,18 @@ def __init__(
         super().__init__()
         self.layer_idx = layer_idx
         config = model_config.pretrained_config
-        self.self_attn = QwenAttention(
-            model_config,
-            layer_idx=layer_idx,
-        )
+
+        if getattr(config, "rope_scaling", None) is not None and getattr(
+                config.rope_scaling, "rope_type", None) == "yarn":
+            self.self_attn = QwenYarnAttention(
+                model_config,
+                layer_idx=layer_idx,
+            )
+        else:
+            self.self_attn = QwenAttention(
+                model_config,
+                layer_idx=layer_idx,
+            )

         self.mlp = GatedMLP(
             hidden_size=config.hidden_size,
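The new dispatch keys off a YaRN rope_scaling entry in the checkpoint's config.json. For illustration only, a block of roughly the following shape would route the layer to QwenYarnAttention; the factor and original_max_position_embeddings values follow the QwQ-32B long-context recommendation and are assumptions, not part of this commit.

# Illustrative rope_scaling entry as it would appear in config.json,
# shown here as a Python dict. The numeric values are assumptions taken
# from the QwQ-32B long-context guidance, not from this diff.
rope_scaling = {
    "rope_type": "yarn",  # key checked by the new QwenDecoderLayer dispatch
    "factor": 4.0,        # context-extension ratio (non-1.0 => YaRN kernel path)
    "original_max_position_embeddings": 32768,
}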

tensorrt_llm/_torch/modules/qk_norm_attention.py

Lines changed: 6 additions & 3 deletions
@@ -158,6 +158,7 @@ def __init__(
         disable_deep_gemm: bool = False,
         use_gemma_rms_norm: bool = False,
         attn_output_gate: Optional[bool] = None,
+        is_qk_norm: bool = True,
     ):
         self.pretrained_config = config.pretrained_config

@@ -169,6 +170,7 @@ def __init__(
         # If fuse_qk_norm_rope is true, do not apply fused RoPE in attention OP, and self.rotary_emb
         # will be skipped in the overridden apply_rope.
         rope_fusion = not self.fuse_qk_norm_rope and not skip_rope and not attn_output_gate and not use_gemma_rms_norm
+        self.is_qk_norm = is_qk_norm
         assert not (fuse_qk_norm_rope and skip_rope
                     ), "Fusing qk norm and skipping rope is not supported"

@@ -192,12 +194,12 @@ def __init__(
         self.q_norm = RMSNorm(hidden_size=self.head_dim,
                               eps=self.pretrained_config.rms_norm_eps,
                               dtype=self.pretrained_config.torch_dtype,
-                              has_weights=True,
+                              has_weights=is_qk_norm,
                               use_gemma=use_gemma_rms_norm)
         self.k_norm = RMSNorm(hidden_size=self.head_dim,
                               eps=self.pretrained_config.rms_norm_eps,
                               dtype=self.pretrained_config.torch_dtype,
-                              has_weights=True,
+                              has_weights=is_qk_norm,
                               use_gemma=use_gemma_rms_norm)
         self.aux_stream = torch.cuda.Stream()
         self.ln_events = [torch.cuda.Event(), torch.cuda.Event()]
@@ -231,7 +233,8 @@ def apply_qk_norm_rope(self, qkv, position_ids):
            self.q_norm.variance_epsilon, self.q_norm.weight,
            self.k_norm.weight,
            self.pos_embd_params.rope.theta, self.pos_embd_params.is_neox,
-           position_ids.view(-1), factor, low, high, attention_factor)
+           position_ids.view(-1), factor, low, high, attention_factor,
+           self.is_qk_norm)
         return qkv, None, None

     def apply_rope(self, q: torch.Tensor, k: Optional[torch.Tensor],
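The net per-head effect of is_qk_norm is a gated RMSNorm before RoPE: with is_qk_norm=False (the QwQ-32B YaRN path) the fused op skips normalization and applies only RoPE. Below is a minimal PyTorch reference of that gate, mirroring the kernel's new if (is_qk_norm) block; RoPE itself is omitted, and the function name is hypothetical.

import torch


def gated_qk_norm(head: torch.Tensor, weight: torch.Tensor, eps: float,
                  is_qk_norm: bool) -> torch.Tensor:
    """RMS-normalize a single q/k head vector, or pass it through unchanged."""
    if not is_qk_norm:
        # QwQ-32B / Qwen2 YaRN path: no QK norm, the fused op only applies RoPE.
        return head
    x = head.float()
    rms_rcp = torch.rsqrt(x.pow(2).mean() + eps)  # 1 / RMS, as in the kernel
    return (x * rms_rcp * weight.float()).to(head.dtype)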

tests/unittest/_torch/thop/parallel/test_fused_qk_norm_rope.py

Lines changed: 2 additions & 1 deletion
@@ -147,7 +147,8 @@ def test_fused_qk_norm_rope(head_dim, num_heads_group, num_tokens, is_neox,
     torch.ops.trtllm.fused_qk_norm_rope(qkv, num_heads_q, num_heads_k,
                                         num_heads_v, head_dim, eps, q_weight,
                                         k_weight, base, is_neox, position_ids,
-                                        factor, low, high, attention_factor)
+                                        factor, low, high, attention_factor,
+                                        True)
     output = qkv  # This op is inplace

     # Compute reference output using TensorRT LLM modules
