
Commit 1c4a046

【OPS】qwen3-next support triton chunk_gated_delta_rule ops (#4070)

### What this PR does / why we need it?
qwen3-next support triton chunk_gated_delta_rule ops

### co-owners
@OsirisDuan

- vLLM version: v0.11.2

Signed-off-by: shiyuan680 <917935075@qq.com>

1 parent 5447a03 commit 1c4a046

File tree

13 files changed: +1625 −149 lines changed

.github/workflows/_e2e_test.yaml

Lines changed: 1 addition & 1 deletion
@@ -276,7 +276,7 @@ jobs:
         shell: bash -l {0}
         run: |
           . /usr/local/Ascend/ascend-toolkit/8.3.RC2/bisheng_toolkit/set_env.sh
-          python3 -m pip install "https://vllm-ascend.obs.cn-north-4.myhuaweicloud.com/vllm-ascend/triton_ascend-3.2.0.dev20250914-cp311-cp311-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl"
+          python3 -m pip install "https://vllm-ascend.obs.cn-north-4.myhuaweicloud.com/vllm-ascend/triton_ascend-3.2.0.dev2025110717-cp311-cp311-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl"

       - name: Run vllm-project/vllm-ascend Qwen3 Next test
         working-directory: ./vllm-ascend
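A quick way to confirm which triton_ascend build actually landed in the CI environment after the pip install step above (illustrative, not part of the commit; it assumes the wheel registers the distribution name "triton_ascend", taken from the wheel filename):

# Not from this commit: sanity-check the installed triton_ascend version.
from importlib.metadata import PackageNotFoundError, version

try:
    print("triton_ascend:", version("triton_ascend"))
except PackageNotFoundError:
    print("triton_ascend is not installed")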
Lines changed: 33 additions & 0 deletions
@@ -0,0 +1,33 @@
import torch

from tests.ut.base import PytestBase
from vllm_ascend.ops.triton.fla.chunk import chunk_gated_delta_rule


class TestChunkGatedDeltaRule(PytestBase):

    def test_triton_fusion_ops(self, mock_moe_env):
        q = torch.randn(1, 17, 4, 128, dtype=torch.bfloat16).npu()
        k = torch.randn(1, 17, 4, 128, dtype=torch.bfloat16).npu()
        v = torch.randn(1, 17, 8, 128, dtype=torch.bfloat16).npu()
        g = torch.randn(1, 17, 8, dtype=torch.float32).npu()
        beta = torch.randn(1, 17, 8, dtype=torch.bfloat16).npu()
        initial_state = torch.randn(3, 8, 128, 128, dtype=torch.bfloat16).npu()
        q_start_loc = torch.range(0, 3, dtype=torch.int).npu()

        (
            core_attn_out_non_spec,
            last_recurrent_state,
        ) = chunk_gated_delta_rule(q=q,
                                   k=k,
                                   v=v,
                                   g=g,
                                   beta=beta,
                                   initial_state=initial_state,
                                   output_final_state=True,
                                   cu_seqlens=q_start_loc,
                                   head_first=False,
                                   use_qk_l2norm_in_kernel=True)

        assert core_attn_out_non_spec.shape == (1, 17, 8, 128)
        assert last_recurrent_state.shape == (3, 8, 128, 128)
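For context on the cu_seqlens argument exercised above: in the FlashAttention-style varlen convention it holds the cumulative token offsets of the packed sequences, so N sequences give N+1 boundaries ending at the total token count. A minimal sketch, with lengths chosen to sum to the 17 tokens used in the test (illustrative, not taken from the commit):

import torch

# Three packed sequences of lengths 5, 7 and 5 (5 + 7 + 5 = 17 tokens in total).
seq_lens = torch.tensor([5, 7, 5], dtype=torch.int32)
cu_seqlens = torch.cat([torch.zeros(1, dtype=torch.int32),
                        torch.cumsum(seq_lens, dim=0, dtype=torch.int32)])
print(cu_seqlens)  # tensor([ 0,  5, 12, 17], dtype=torch.int32)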

vllm_ascend/models/qwen3_next.py

Lines changed: 14 additions & 44 deletions
@@ -423,50 +423,20 @@ def _forward(
                 non_spec_state_indices_tensor].contiguous()
             initial_state[~has_initial_state, ...] = 0

-            batch_size = initial_state.shape[0]
-            core_attn_out = []
-            last_recurrent_state = []
-
-            for b_idx in range(batch_size):
-                start, end = non_spec_query_start_loc[
-                    b_idx], non_spec_query_start_loc[b_idx + 1]
-                cur_q = query_non_spec[:, start:end, ...]
-                cur_k = key_non_spec[:, start:end, ...]
-                cur_v = value_non_spec[:, start:end, ...]
-                cur_g = g_non_spec[:, start:end, ...]
-                cur_b = beta_non_spec[:, start:end, ...]
-                cur_state = initial_state[b_idx].unsqueeze(0)
-
-                (
-                    cur_core_attn_out_non_spec,
-                    cur_last_recurrent_state,
-                ) = chunk.chunk_gated_delta_rule(
-                    query=cur_q,
-                    key=cur_k,
-                    value=cur_v,
-                    g=cur_g,
-                    beta=cur_b,
-                    initial_state=cur_state,
-                    output_final_state=True,
-                    use_qk_l2norm_in_kernel=True,
-                )
-
-                core_attn_out.append(cur_core_attn_out_non_spec)
-                last_recurrent_state.append(cur_last_recurrent_state)
-
-            tar_dtype = core_attn_out[0].dtype
-            tar_device = core_attn_out[0].device
-            tar_shape = list(core_attn_out[0].shape)
-            tar_shape[1] = non_spec_query_start_loc[-1]
-            core_attn_out_non_spec = torch.empty(tar_shape,
-                                                 dtype=tar_dtype,
-                                                 device=tar_device)
-            for b_idx in range(batch_size):
-                cur_core_attn_out = core_attn_out[b_idx]
-                start, end = non_spec_query_start_loc[
-                    b_idx], non_spec_query_start_loc[b_idx + 1]
-                core_attn_out_non_spec[:, start:end, ...] = cur_core_attn_out
-            last_recurrent_state = torch.cat(last_recurrent_state, dim=0)
+            (
+                core_attn_out_non_spec,
+                last_recurrent_state,
+            ) = chunk.chunk_gated_delta_rule(
+                q=query_non_spec,
+                k=key_non_spec,
+                v=value_non_spec,
+                g=g_non_spec,
+                beta=beta_non_spec,
+                initial_state=initial_state,
+                output_final_state=True,
+                cu_seqlens=non_spec_query_start_loc,
+                head_first=False,
+                use_qk_l2norm_in_kernel=True)

             # Init cache
             ssm_state[non_spec_state_indices_tensor] = last_recurrent_state.to(
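The refactor above drops the per-sequence Python loop and issues a single varlen call over the packed token dimension, letting cu_seqlens carry the sequence boundaries into the kernel. A small sketch of the packed layout this relies on, and how per-sequence chunks relate to it (shapes are illustrative, not taken from the model):

import torch

cu_seqlens = torch.tensor([0, 5, 12, 17])
packed = torch.randn(1, 17, 8, 128)                    # [1, T_total, H, V]; batch dim stays 1
lengths = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()  # [5, 7, 5]
per_seq = torch.split(packed, lengths, dim=1)          # what the removed loop sliced out one by one
assert [t.shape[1] for t in per_seq] == lengths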
Lines changed: 226 additions & 0 deletions
@@ -0,0 +1,226 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# SPDX-FileCopyrightText: Songlin Yang, Yu Zhang
#
# This file contains code copied from the flash-linear-attention project.
# The original source code was licensed under the MIT license and included
# the following copyright notice:
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
# ruff: noqa: E501
# mypy: ignore-errors
import warnings
from typing import Optional

import torch
from einops import rearrange
from vllm.model_executor.layers.fla.ops.l2norm import l2norm_fwd
from vllm.model_executor.layers.fla.ops.utils import SUPPRESS_LEVEL

from .chunk_delta_h import chunk_gated_delta_rule_fwd_h
from .chunk_o import chunk_fwd_o
from .chunk_scaled_dot_kkt import chunk_scaled_dot_kkt_fwd
from .cumsum import chunk_local_cumsum
from .solve_tril import solve_tril
from .utils import input_guard
from .wy_fast import recompute_w_u_fwd


def chunk_gated_delta_rule_fwd(q: torch.Tensor,
                               k: torch.Tensor,
                               v: torch.Tensor,
                               g: torch.Tensor,
                               beta: torch.Tensor,
                               scale: float,
                               initial_state: torch.Tensor,
                               output_final_state: bool,
                               cu_seqlens: Optional[torch.LongTensor] = None):
    g = chunk_local_cumsum(g, chunk_size=64, cu_seqlens=cu_seqlens)
    # obtain WY representation. u is actually the new v.
    A = chunk_scaled_dot_kkt_fwd(k=k,
                                 beta=beta,
                                 g_cumsum=g,
                                 cu_seqlens=cu_seqlens,
                                 output_dtype=torch.float32)
    A = solve_tril(A=A, cu_seqlens=cu_seqlens, output_dtype=k.dtype)
    w, u = recompute_w_u_fwd(
        k=k,
        v=v,
        beta=beta,
        A=A,
        g_cumsum=g,
        cu_seqlens=cu_seqlens,
    )
    h, v_new, final_state = chunk_gated_delta_rule_fwd_h(
        k=k,
        w=w,
        u=u,
        g=g,
        initial_state=initial_state,
        output_final_state=output_final_state,
        cu_seqlens=cu_seqlens,
    )
    o = chunk_fwd_o(
        q=q,
        k=k,
        v=v_new,
        h=h,
        g=g,
        scale=scale,
        cu_seqlens=cu_seqlens,
    )
    if SUPPRESS_LEVEL < 3:
        return g, o, A, final_state, None, None, None
    elif SUPPRESS_LEVEL >= 3:
        return g, o, A, final_state, w, h, v_new


class ChunkGatedDeltaRuleFunction(torch.autograd.Function):

    @staticmethod
    @input_guard
    def forward(ctx,
                q: torch.Tensor,
                k: torch.Tensor,
                v: torch.Tensor,
                g: torch.Tensor,
                beta: torch.Tensor,
                scale: float,
                initial_state: torch.Tensor,
                output_final_state: bool,
                cu_seqlens: Optional[torch.LongTensor] = None,
                use_qk_l2norm_in_kernel: bool = False):
        if use_qk_l2norm_in_kernel:
            q = l2norm_fwd(q)
            k = l2norm_fwd(k)

        g, o, A, final_state, w, h, v_new = chunk_gated_delta_rule_fwd(
            q=q,
            k=k,
            v=v,
            g=g,
            beta=beta,
            scale=scale,
            initial_state=initial_state,
            output_final_state=output_final_state,
            cu_seqlens=cu_seqlens,
        )
        ctx.scale = scale
        ctx.use_qk_l2norm_in_kernel = use_qk_l2norm_in_kernel
        return o.to(q.dtype), final_state


@torch.compiler.disable
def chunk_gated_delta_rule(q: torch.Tensor,
                           k: torch.Tensor,
                           v: torch.Tensor,
                           g: torch.Tensor,
                           beta: torch.Tensor,
                           scale: float = None,
                           initial_state: torch.Tensor = None,
                           output_final_state: bool = False,
                           cu_seqlens: Optional[torch.LongTensor] = None,
                           head_first: bool = False,
                           use_qk_l2norm_in_kernel: bool = False):
    r"""
    Args:
        q (torch.Tensor):
            queries of shape `[B, T, H, K]` if `head_first=False` else `[B, H, T, K]`.
        k (torch.Tensor):
            keys of shape `[B, T, H, K]` if `head_first=False` else `[B, H, T, K]`.
        v (torch.Tensor):
            values of shape `[B, T, H, V]` if `head_first=False` else `[B, H, T, V]`.
        g (torch.Tensor):
            (forget) gating tensor (in log space!) of shape `[B, T, H]` if `head_first=False` else `[B, H, T]`.
        beta (torch.Tensor):
            betas of shape `[B, T, H]` if `head_first=False` else `[B, H, T]`.
        scale (Optional[int]):
            Scale factor for the RetNet attention scores.
            If not provided, it will default to `1 / sqrt(K)`. Default: `None`.
        initial_state (Optional[torch.Tensor]):
            Initial state of shape `[N, H, K, V]` for `N` input sequences.
            For equal-length input sequences, `N` equals the batch size `B`.
            Default: `None`.
        output_final_state (Optional[bool]):
            Whether to output the final state of shape `[N, H, K, V]`. Default: `False`.
        cu_seqlens (torch.LongTensor):
            Cumulative sequence lengths of shape `[N+1]` used for variable-length training,
            consistent with the FlashAttention API.
        head_first (Optional[bool]):
            Whether the inputs are in the head-first format, which is not supported for variable-length inputs.
            Default: `False`.

    Returns:
        o (torch.Tensor):
            Outputs of shape `[B, T, H, V]` if `head_first=False` else `[B, H, T, V]`.
        final_state (torch.Tensor):
            Final state of shape `[N, H, K, V]` if `output_final_state=True` else `None`.

    Examples::
        >>> import torch
        >>> import torch.nn.functional as F
        >>> from einops import rearrange
        >>> from fla.ops.gated_delta_rule import chunk_gated_delta_rule
        # inputs with equal lengths
        >>> B, T, H, K, V = 4, 2048, 4, 512, 512
        >>> q = torch.randn(B, T, H, K, dtype=torch.bfloat16, device='cuda')
        >>> k = F.normalize(torch.randn(B, T, H, K, dtype=torch.bfloat16, device='cuda'), p=2, dim=-1)
        >>> v = torch.randn(B, T, H, V, dtype=torch.bfloat16, device='cuda')
        >>> beta = torch.rand(B, T, H, dtype=torch.bfloat16, device='cuda').sigmoid()
        >>> g = F.logsigmoid(torch.rand(B, T, H, dtype=torch.bfloat16, device='cuda'))
        >>> h0 = torch.randn(B, H, K, V, dtype=torch.bfloat16, device='cuda')
        >>> o, ht = chunk_gated_delta_rule(
            q, k, v, g, beta,
            initial_state=h0,
            output_final_state=True
        )
        # for variable-length inputs, the batch size `B` is expected to be 1 and `cu_seqlens` is required
        >>> q, k, v, beta, g = map(lambda x: rearrange(x, 'b t ... -> 1 (b t) ...'), (q, k, v, beta, g))
        # for a batch with 4 sequences, `cu_seqlens` with 5 start/end positions are expected
        >>> cu_seqlens = q.new_tensor([0, 2048, 4096, 6144, 8192], dtype=torch.long)
        >>> o_var, ht_var = chunk_gated_delta_rule(
            q, k, v, g, beta,
            initial_state=h0,
            output_final_state=True,
            cu_seqlens=cu_seqlens
        )
    """
    assert q.dtype == k.dtype == v.dtype
    assert q.dtype != torch.float32, "ChunkGatedDeltaRuleFunction does not support float32. Please use bfloat16."
    assert len(
        beta.shape
    ) == 3, "beta must be of shape [B, T, H] if head_first=False, or [B, H, T] otherwise."

    if head_first:
        raise DeprecationWarning(
            "head_first is deprecated and will be removed in a future version. "
            "Please use head_first=False for now instead.",
            stacklevel=2)
        q, k, v, beta, g = map(
            lambda x: rearrange(x, 'b h t ... -> b t h ...'),
            (q, k, v, beta, g))
    if not head_first and q.shape[1] < q.shape[2]:
        warnings.warn(
            f"Input tensor shape suggests potential format mismatch: seq_len ({q.shape[1]}) < num_heads ({q.shape[2]}). "
            "This may indicate the inputs were passed in head-first format [B, H, T, ...] "
            "when head_first=False was specified. "
            "Please verify your input tensor format matches the expected shape [B, T, H, ...].",
            stacklevel=2)
    if cu_seqlens is not None:
        if q.shape[0] != 1:
            raise ValueError(
                f"The batch size is expected to be 1 rather than {q.shape[0]} when using `cu_seqlens`."
                f"Please flatten variable-length inputs before processing.")
        if initial_state is not None and initial_state.shape[0] != len(
                cu_seqlens) - 1:
            raise ValueError(
                f"The number of initial states is expected to be equal to the number of input sequences, "
                f"i.e., {len(cu_seqlens) - 1} rather than {initial_state.shape[0]}."
            )
    if scale is None:
        scale = k.shape[-1]**-0.5
    o, final_state = ChunkGatedDeltaRuleFunction.apply(
        q, k, v, g, beta, scale, initial_state, output_final_state, cu_seqlens,
        use_qk_l2norm_in_kernel)
    if head_first:
        o = rearrange(o, 'b t h ... -> b h t ...')
    return o, final_state
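As a reading aid for the pipeline above (local cumulative gates, WY representation via chunk_scaled_dot_kkt_fwd / solve_tril / recompute_w_u_fwd, chunked state propagation in chunk_gated_delta_rule_fwd_h, then the output in chunk_fwd_o), here is a naive per-token recurrence that the chunked formulation is intended to match up to numerics. This sketch is not part of the PR; it assumes q and k are already L2-normalized when comparing against use_qk_l2norm_in_kernel=True, and it uses the state update S_t = exp(g_t) * S_{t-1} + beta_t * k_t (v_t - S_gated^T k_t)^T:

import torch


def gated_delta_rule_reference(q, k, v, g, beta, scale=None):
    # Shapes follow head_first=False: q, k: [B, T, H, K]; v: [B, T, H, V]; g, beta: [B, T, H].
    B, T, H, K = q.shape
    V = v.shape[-1]
    scale = K ** -0.5 if scale is None else scale
    S = q.new_zeros(B, H, K, V, dtype=torch.float32)     # recurrent [K, V] state per head
    o = torch.empty_like(v, dtype=torch.float32)
    for t in range(T):
        k_t = k[:, t].float()                            # [B, H, K]
        v_t = v[:, t].float()                            # [B, H, V]
        S = S * g[:, t].float().exp()[..., None, None]   # forget gate (g is in log space)
        # delta-rule update: write only the part of v_t not already stored under k_t
        v_err = v_t - torch.einsum('bhk,bhkv->bhv', k_t, S)
        S = S + beta[:, t].float()[..., None, None] * k_t[..., None] * v_err[..., None, :]
        o[:, t] = torch.einsum('bhk,bhkv->bhv', q[:, t].float() * scale, S)
    return o.to(v.dtype), S

With a non-zero initial_state, S would simply start from it instead of zeros; the chunked path above exposes this through the initial_state argument and returns the final S when output_final_state=True.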
