Commit 99b3a50

[Qwen3-Next][GDN] fixes cuda graph capturing bug in GDN metadata and a stride bug in causal_conv_1d. (vllm-project#25743)
Signed-off-by: Tao He <linzhu.ht@alibaba-inc.com>
1 parent: 6e30010

3 files changed: +50 additions, -45 deletions

vllm/model_executor/layers/mamba/ops/causal_conv1d.py (8 additions, 3 deletions)

@@ -41,6 +41,7 @@ def _causal_conv1d_fwd_kernel(  # continuous batching
     stride_istate_seq: tl.constexpr,
     stride_istate_dim: tl.constexpr,
     stride_istate_token: tl.constexpr,
+    stride_cache_indices: tl.constexpr,
     stride_o_seq: tl.constexpr,
     stride_o_dim: tl.constexpr,
     stride_o_token: tl.constexpr,
@@ -69,7 +70,7 @@ def _causal_conv1d_fwd_kernel(  # continuous batching
     # rather than mixing sequences - to make updating initial_states across sequences efficiently

     # single-sequence id
-    idx_seq = tl.load(batch_ptr + tl.program_id(0))
+    idx_seq = tl.load(batch_ptr + tl.program_id(0)).to(tl.int64)
     chunk_offset = tl.load(token_chunk_offset_ptr + tl.program_id(0))

     # BLOCK_N elements along the feature-dimension (channel)
@@ -91,8 +92,9 @@ def _causal_conv1d_fwd_kernel(  # continuous batching

     if IS_CONTINUOUS_BATCHING:
         # cache_idx
-        conv_state_batch_coord = tl.load(conv_state_indices_ptr + idx_seq).to(
-            tl.int64)
+        conv_state_batch_coord = tl.load(conv_state_indices_ptr +
+                                         idx_seq * stride_cache_indices).to(
+                                             tl.int64)
     else:
         # cache_idx
         conv_state_batch_coord = idx_seq
@@ -480,6 +482,8 @@ def causal_conv1d_fn(
     stride_o_seq = out.stride(0)
     stride_o_dim = out.stride(1)
     stride_o_token = out.stride(2)
+    stride_cache_indices = cache_indices.stride(
+        0) if cache_indices is not None else 0

     if validate_data:
         assert x.dim() == 2
@@ -595,6 +599,7 @@ def grid(META):
             stride_istate_seq,
             stride_istate_dim,
             stride_istate_token,
+            stride_cache_indices,
             stride_o_seq,
             stride_o_dim,
             stride_o_token,
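
A minimal PyTorch sketch (plain tensor indexing, not the Triton kernel, with made-up values) of the stride bug this file fixes: when cache_indices is a non-contiguous view, element i lives at base + i * stride(0), so adding the raw index to the base pointer without the stride reads the wrong cache slot.

import torch

# Made-up values; plain PyTorch rather than the Triton kernel.
table = torch.arange(12, dtype=torch.int32).reshape(3, 4)
cache_indices = table[:, 1]                  # non-contiguous view, stride(0) == 4
base = cache_indices.storage_offset()        # where the view starts in storage
flat = table.flatten()                       # the raw storage a pointer would see

idx_seq = 2
buggy = flat[base + idx_seq]                             # lands on table[0, 3]
fixed = flat[base + idx_seq * cache_indices.stride(0)]   # lands on table[2, 1]
assert fixed == cache_indices[idx_seq] and buggy != fixed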

vllm/v1/attention/backends/gdn_attn.py (26 additions, 35 deletions)

@@ -125,31 +125,33 @@ def build(  # type: ignore[override]
         common_prefix_len: int,
         common_attn_metadata: CommonAttentionMetadata,
         num_accepted_tokens: Optional[torch.Tensor] = None,
-        num_draft_tokens: Optional[torch.Tensor] = None,
+        num_decode_draft_tokens_cpu: Optional[torch.Tensor] = None,
         fast_build: bool = False,
     ) -> GDNAttentionMetadata:
         m = common_attn_metadata

         query_start_loc = m.query_start_loc
         context_lens = m.num_computed_tokens_cpu
         context_lens_tensor = context_lens.to(query_start_loc.device)
-        seq_lens_tensor = m.seq_lens
         nums_dict, batch_ptr, token_chunk_offset_ptr = None, None, None

-        if (not self.use_spec_decode or num_draft_tokens is None
-                or num_draft_tokens.sum().item() == 0):
+        if (not self.use_spec_decode or num_decode_draft_tokens_cpu is None
+                or num_decode_draft_tokens_cpu[num_decode_draft_tokens_cpu >=
+                                               0].sum().item() == 0):
             spec_sequence_masks = None
+            num_spec_decodes = 0
         else:
-            spec_sequence_masks = (num_draft_tokens > 0) & (
-                context_lens_tensor +
-                (num_draft_tokens + 1) == seq_lens_tensor)
-            if spec_sequence_masks.sum().item() == 0:
+            spec_sequence_masks = num_decode_draft_tokens_cpu >= 0
+            num_spec_decodes = spec_sequence_masks.sum().item()
+            if num_spec_decodes == 0:
                 spec_sequence_masks = None
+            else:
+                spec_sequence_masks = spec_sequence_masks.to(
+                    query_start_loc.device, non_blocking=True)

         if spec_sequence_masks is None:
             num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = (
                 split_decodes_and_prefills(m, decode_threshold=1))
-            num_spec_decodes = 0
             num_spec_decode_tokens = 0
             spec_token_masks = None
             spec_state_indices_tensor = None
@@ -158,7 +160,6 @@ def build(  # type: ignore[override]
             non_spec_query_start_loc = query_start_loc
             num_accepted_tokens = None
         else:
-            num_spec_decodes = spec_sequence_masks.sum().item()
             query_lens = query_start_loc[1:] - query_start_loc[:-1]

             non_spec_query_lens = query_lens[~spec_sequence_masks]
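
A rough sketch (values invented) of the rewritten bookkeeping in build(): the speculative-decode mask is now derived directly from the -1-masked CPU tensor instead of comparing context and sequence lengths on the device.

import torch

# Invented example: -1 marks requests that are not pure decodes (e.g. still in
# chunked prefill); entries >= 0 are per-request draft-token counts.
num_decode_draft_tokens_cpu = torch.tensor([-1, 2, -1, 0], dtype=torch.int32)

# Early-exit test from the new condition: sum only the non-negative entries.
no_spec_work = num_decode_draft_tokens_cpu[
    num_decode_draft_tokens_cpu >= 0].sum().item() == 0      # False here

spec_sequence_masks = num_decode_draft_tokens_cpu >= 0       # [F, T, F, T]
num_spec_decodes = spec_sequence_masks.sum().item()          # 2
# If num_spec_decodes > 0, the mask is moved to the GPU with non_blocking=True.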
@@ -314,28 +315,18 @@ def build_for_cudagraph_capture(
         """
         m = common_attn_metadata

-        assert (m.num_reqs * (self.num_spec + 1) <= m.num_actual_tokens
-                and ((m.num_reqs + 1) * (self.num_spec + 1)
-                     >= m.num_actual_tokens)), \
-            "GDN only supports decode-only full CUDAGraph capture. " \
-            "Make sure all cudagraph capture sizes <= max_num_seq."
-
-        num_accepted_tokens = torch.full((m.num_reqs, ),
-                                         m.max_query_len,
-                                         dtype=torch.int32,
-                                         device=m.query_start_loc.device)
-        num_drafted_tokens = torch.full((m.num_reqs, ),
-                                        self.num_spec,
-                                        dtype=torch.int32,
-                                        device=m.query_start_loc.device)
-
-        # Fixes query-start loc for spec-sequence-indices.
-        m.query_start_loc = torch.arange(0,
-                                         m.num_actual_tokens + 1,
-                                         step=m.max_query_len,
-                                         device=m.query_start_loc.device,
-                                         dtype=torch.int32)
-        m.num_computed_tokens_cpu = (m.seq_lens_cpu - torch.full(
-            (m.num_reqs, ), m.max_query_len, dtype=torch.int32, device='cpu'))
-
-        return self.build(0, m, num_accepted_tokens, num_drafted_tokens)
+        assert (
+            m.num_reqs <= self.decode_cudagraph_max_bs
+            and m.num_actual_tokens <= self.decode_cudagraph_max_bs), (
+                f"GDN only supports decode-only full CUDAGraph capture. "
+                f"Make sure batch size ({m.num_reqs}) <= "
+                f"cudagraph capture sizes ({self.decode_cudagraph_max_bs}), "
+                f"and number of tokens ({m.num_actual_tokens}) <= "
+                f"cudagraph capture sizes ({self.decode_cudagraph_max_bs}).")
+
+        num_accepted_tokens = torch.diff(m.query_start_loc)
+        num_decode_draft_tokens_cpu = (num_accepted_tokens - 1).cpu()
+        m.num_computed_tokens_cpu = m.seq_lens_cpu - num_accepted_tokens.cpu()
+
+        return self.build(0, m, num_accepted_tokens,
+                          num_decode_draft_tokens_cpu)
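
A short walk-through (batch shape and lengths assumed, not vLLM defaults) of what the new build_for_cudagraph_capture derives from a dummy decode-only capture batch before delegating to build():

import torch

# Assumed dummy capture batch: 4 requests, max_query_len == num_spec + 1 == 3.
query_start_loc = torch.tensor([0, 3, 6, 9, 12], dtype=torch.int32)
seq_lens_cpu = torch.tensor([8, 8, 8, 8], dtype=torch.int32)

num_accepted_tokens = torch.diff(query_start_loc)                   # [3, 3, 3, 3]
num_decode_draft_tokens_cpu = (num_accepted_tokens - 1).cpu()       # [2, 2, 2, 2]
num_computed_tokens_cpu = seq_lens_cpu - num_accepted_tokens.cpu()  # [5, 5, 5, 5]

# Every entry is >= 0, so build() treats the whole dummy batch as speculative
# decodes and the captured CUDA graph covers the spec-decode path.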

vllm/v1/worker/gpu_model_runner.py (16 additions, 7 deletions)

@@ -360,8 +360,8 @@ def __init__(
                                                   dtype=torch.int64)
         self.num_discarded_requests = 0

-        self.num_draft_tokens = self._make_buffer(self.max_num_reqs,
-                                                  dtype=torch.int32)
+        self.num_decode_draft_tokens = self._make_buffer(self.max_num_reqs,
+                                                         dtype=torch.int32)
         self.num_accepted_tokens = self._make_buffer(self.max_num_reqs,
                                                      dtype=torch.int64)
@@ -1103,17 +1103,25 @@ def _prepare_inputs(
             # Iterate over the dictionary rather than all requests since not all
             # requests have draft tokens.
             num_draft_tokens = np.zeros(num_reqs, dtype=np.int32)
+            # For chunked prefills, use -1 as mask rather than 0, as guided
+            # decoding may rollback speculative tokens.
+            num_decode_draft_tokens = np.full(num_reqs, -1, dtype=np.int32)
             for req_id, draft_token_ids in (
                     scheduler_output.scheduled_spec_decode_tokens.items()):
                 req_idx = self.input_batch.req_id_to_index[req_id]
                 num_draft_tokens[req_idx] = len(draft_token_ids)
-
+                num_decode_draft_tokens[req_idx] = (len(draft_token_ids) if (
+                    self.input_batch.num_computed_tokens_cpu[req_idx]
+                    >= self.input_batch.num_prompt_tokens[req_idx]) else -1)
             spec_decode_metadata = self._calc_spec_decode_metadata(
                 num_draft_tokens, cu_num_tokens)
             logits_indices = spec_decode_metadata.logits_indices
-            self.num_draft_tokens.np[:num_reqs] = num_draft_tokens
-            self.num_draft_tokens.np[num_reqs:].fill(0)
-            self.num_draft_tokens.copy_to_gpu()
+
+            # For DECODE only cuda graph of some attention backends (e.g., GDN).
+            self.num_decode_draft_tokens.np[:
+                                            num_reqs] = num_decode_draft_tokens
+            self.num_decode_draft_tokens.np[num_reqs:].fill(-1)
+            self.num_decode_draft_tokens.copy_to_gpu()

             logits_indices_padded = None
             if self.cache_config.kv_sharing_fast_prefill:
@@ -1217,7 +1225,8 @@ def _prepare_inputs(
                 extra_attn_metadata_args = dict(
                     num_accepted_tokens=self.num_accepted_tokens.
                     gpu[:num_reqs],
-                    num_draft_tokens=self.num_draft_tokens.gpu[:num_reqs],
+                    num_decode_draft_tokens_cpu=self.
+                    num_decode_draft_tokens.cpu[:num_reqs],
                 )

                 if ubatch_slices is not None:
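
A hypothetical mini-example (request sizes invented) of the -1 sentinel written into num_decode_draft_tokens: requests still inside their prompt (chunked prefill) stay masked at -1 even when drafts are scheduled, while requests past their prompt record the actual draft count:

import numpy as np

# Invented example: 3 requests, only request 2 has finished its prompt.
num_reqs = 3
num_computed_tokens = np.array([5, 20, 32])   # per-request progress (assumed)
num_prompt_tokens = np.array([16, 16, 32])    # prompt lengths (assumed)
drafts_scheduled = {0: 2, 2: 2}               # req_idx -> len(draft_token_ids)

num_decode_draft_tokens = np.full(num_reqs, -1, dtype=np.int32)
for req_idx, n_draft in drafts_scheduled.items():
    in_decode = num_computed_tokens[req_idx] >= num_prompt_tokens[req_idx]
    num_decode_draft_tokens[req_idx] = n_draft if in_decode else -1

print(num_decode_draft_tokens)  # [-1 -1  2]: only request 2 is a decode step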
