Commit 1986de1

[Perf] Optimize EAGLE prepare_inputs_padded with triton kernels (#28597)

Signed-off-by: Benjamin Chislett <bchislett@nvidia.com>
Signed-off-by: Benjamin Chislett <chislett.ben@gmail.com>

1 parent 3461e7e commit 1986de1

File tree: 4 files changed, +198 −107 lines


tests/v1/spec_decode/test_eagle.py (11 additions, 19 deletions)

@@ -103,16 +103,23 @@ def test_prepare_next_token_ids():
         mock_request.num_computed_tokens = 0
         mock_requests[req_id] = mock_request
 
+    # explicitly discard the last request
+    discarded_req_mask = torch.tensor(
+        [False, False, False, True], dtype=torch.bool, device=device
+    )
     sampled_token_ids = [
         [0, 1, -1, -1, -1],  # 1 accepted, 3 rejected, "1" sampled
         [0, 1, 2, 3, 4],  # all accepted, "4" sampled
         [-1, -1, -1, -1, -1],  # sampling skipped, use backup token "30"
-        [-1, -1, -1, -1, -1],  # this request will be discarded
+        [0, 1, 2, -1, -1],  # explicitly discarded, sampling should be ignored
     ]
     sampled_token_ids_tensor = torch.tensor(
         sampled_token_ids, dtype=torch.int32, device=device
     )
     sampled_token_ids_cpu = [[i for i in seq if i != -1] for seq in sampled_token_ids]
+    for i in range(len(sampled_token_ids_cpu)):
+        if discarded_req_mask[i]:
+            sampled_token_ids_cpu[i] = []
 
     expected_next_token_ids_cpu = [1, 4, 30, 40]
     expected_next_token_ids_tensor = torch.tensor(
@@ -136,9 +143,6 @@ def test_prepare_next_token_ids():
         device=device,
     )
 
-    discarded_req_indices = torch.tensor([3], dtype=torch.int64, device=device)
-    num_discarded_reqs = 1
-
     expected_valid_sampled_tokens_count = torch.tensor(
         [2, 5, 0, 0], dtype=torch.int32, device=device
     )
@@ -149,8 +153,7 @@ def test_prepare_next_token_ids():
             sampled_token_ids_tensor,
            mock_requests,
             mock_input_batch,
-            discarded_req_indices,
-            num_discarded_reqs,
+            discarded_req_mask,
         )
     )
 
@@ -256,21 +259,13 @@ def test_prepare_inputs_padded():
    - Request 3: query_len = 3, rejected = 2
 
    Expected outputs:
-    token_indices: [0, 1, 2,
-                    3, 4, 5,
-                    6, 7, 8]
-    Reason: Deferred computation should not disturb the original indices.
-
    token_indices_to_sample: [1, 5, 6]
    Reason: After accounting for rejections, these are the valid token positions
    from the original indices to sample from.
    """
 
    device = torch.device(current_platform.device_type)
 
-    expected_token_indices = torch.tensor(
-        [0, 1, 2, 3, 4, 5, 6, 7, 8], dtype=torch.int32, device=device
-    )
    expected_token_indices_to_sample = torch.tensor(
        [1, 5, 6], dtype=torch.int32, device=device
    )
@@ -305,15 +300,12 @@ def test_prepare_inputs_padded():
 
    proposer = _create_proposer("eagle", num_speculative_tokens)
 
-    output_metadata, token_indices, token_indices_to_sample = (
-        proposer.prepare_inputs_padded(
-            common_attn_metadata, spec_decode_metadata, valid_sampled_tokens_count
-        )
+    output_metadata, token_indices_to_sample = proposer.prepare_inputs_padded(
+        common_attn_metadata, spec_decode_metadata, valid_sampled_tokens_count
    )
 
    assert output_metadata.max_query_len == 3
    assert torch.equal(output_metadata.query_start_loc, expected_query_start_loc)
-    assert torch.equal(token_indices, expected_token_indices)
    assert torch.equal(token_indices_to_sample, expected_token_indices_to_sample)
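The test now passes a dense boolean mask instead of an index tensor plus a count. As a minimal sketch that is not part of this commit, a caller still holding the old index-based inputs could build the new discard_request_mask argument like this (the values shown are the ones used in the test above):

import torch

# Hypothetical adapter (not from this commit): derive the boolean
# discard_request_mask expected by the new API from the index-based
# inputs used before the change.
num_reqs = 4
discard_request_indices = torch.tensor([3], dtype=torch.int64)
num_discarded_requests = 1

discard_request_mask = torch.zeros(num_reqs, dtype=torch.bool)
discard_request_mask[discard_request_indices[:num_discarded_requests]] = True
print(discard_request_mask)  # tensor([False, False, False,  True])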
vllm/v1/spec_decode/eagle.py (49 additions, 58 deletions)

@@ -25,6 +25,7 @@
 from vllm.model_executor.models.llama_eagle3 import Eagle3LlamaForCausalLM
 from vllm.multimodal import MULTIMODAL_REGISTRY
 from vllm.platforms import current_platform
+from vllm.triton_utils import triton
 from vllm.utils.platform_utils import is_pin_memory_available
 from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata
 from vllm.v1.attention.backends.tree_attn import (
@@ -40,6 +41,10 @@
 from vllm.v1.sample.metadata import SamplingMetadata
 from vllm.v1.sample.sampler import _SAMPLING_EPS
 from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
+from vllm.v1.spec_decode.utils import (
+    eagle_prepare_inputs_padded_kernel,
+    eagle_prepare_next_token_padded_kernel,
+)
 from vllm.v1.utils import CpuGpuBuffer
 from vllm.v1.worker.dp_utils import coordinate_batch_across_dp
 from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
@@ -555,20 +560,15 @@ def prepare_next_token_ids_padded(
         sampled_token_ids: torch.Tensor,
         requests: dict[str, CachedRequestState],
         gpu_input_batch: InputBatch,
-        discard_request_indices: torch.Tensor,
-        num_discarded_requests: int,
+        discard_request_mask: torch.Tensor,
     ) -> tuple[torch.Tensor, torch.Tensor]:
         """
         This function is used to prepare the inputs for speculative decoding.
         It calculates the next token ids and the number of valid sampled tokens
         for each request, considering the "discarded" requests whose next token
-        is not sampled and comes from `request.get_token_id()` instead.
-        It also accounts for the rejected tokens in `sampled_token_ids`.
-        This function must use device functions to operate on the inputs, and
-        should not introduce any blocking CPU-GPU synchronization.
+        is not sampled and comes from `request.get_token_id()` instead. This is denoted
+        the "backup" token id. It also counts rejected tokens via `sampled_token_ids`.
         """
-        # TODO(Ben): Combine this into a custom fused kernel
-
         # Precompute get_token_id for when there is no valid next token
         num_reqs = gpu_input_batch.num_reqs
         self.backup_next_token_ids.np[:num_reqs] = np.array(
@@ -577,44 +577,39 @@ def prepare_next_token_ids_padded(
                     common_attn_metadata.seq_lens_cpu[i].item()
                 )
                 for i in range(num_reqs)
-            ]
+            ],
+            dtype=np.int32,
         )
         self.backup_next_token_ids.copy_to_gpu(num_reqs)
+        backup_tokens_gpu = self.backup_next_token_ids.gpu
 
-        # Mask out the sampled tokens indices that should not be sampled.
-        discard_sampled_tokens_req_indices = discard_request_indices[
-            :num_discarded_requests
-        ]
+        batch_size, num_tokens = sampled_token_ids.shape
+        device = sampled_token_ids.device
 
-        valid_sampled_token_ids_gpu = sampled_token_ids.clone()
-        valid_sampled_token_ids_gpu.index_fill_(
-            0, discard_sampled_tokens_req_indices, -1
-        )
+        assert discard_request_mask.dtype == torch.bool
+        assert backup_tokens_gpu.dtype == torch.int32
 
-        # Generate a mask for all valid tokens within those requests
-        valid_mask = (valid_sampled_token_ids_gpu != -1) & (
-            valid_sampled_token_ids_gpu < gpu_input_batch.vocab_size
+        next_token_ids = torch.empty((batch_size,), dtype=torch.int32, device=device)
+        valid_sampled_tokens_count = torch.empty(
+            (batch_size,), dtype=torch.int32, device=device
         )
 
-        # Count the number of valid tokens in each request
-        valid_sampled_tokens_count = valid_mask.sum(dim=1)
+        # Kernel grid: one program per request (row)
+        grid = (batch_size,)
 
-        # Get the rightmost valid index per row
-        last_valid_indices = valid_sampled_tokens_count - 1
-        last_valid_indices_safe = torch.clamp(last_valid_indices, min=0)
-
-        # Get last valid token from each row
-        # (assume undefined state where there is no valid token)
-        selected_tokens = torch.gather(
-            valid_sampled_token_ids_gpu, 1, last_valid_indices_safe.unsqueeze(1)
-        ).squeeze(1)
-
-        # Use last token if valid, pre-computed backup if not
-        batch_size = valid_sampled_token_ids_gpu.shape[0]
-        next_token_ids = torch.where(
-            last_valid_indices != -1,
-            selected_tokens,
-            self.backup_next_token_ids.gpu[:batch_size],
+        # Find the next power of 2 for block sizes
+        BLOCK_SIZE_TOKENS = triton.next_power_of_2(num_tokens)
+        eagle_prepare_next_token_padded_kernel[grid](
+            sampled_token_ids,
+            discard_request_mask,
+            backup_tokens_gpu,
+            next_token_ids,
+            valid_sampled_tokens_count,
+            gpu_input_batch.vocab_size,
+            num_tokens,
+            batch_size,
+            sampled_token_ids.stride(0),
+            BLOCK_SIZE_TOKENS=BLOCK_SIZE_TOKENS,
         )
 
         return next_token_ids, valid_sampled_tokens_count
@@ -624,35 +619,35 @@ def prepare_inputs_padded(
         common_attn_metadata: CommonAttentionMetadata,
         spec_decode_metadata: SpecDecodeMetadata,
         valid_sampled_tokens_count: torch.Tensor,
-    ) -> tuple[CommonAttentionMetadata, torch.Tensor, torch.Tensor]:
+    ) -> tuple[CommonAttentionMetadata, torch.Tensor]:
         """
         This function is used to prepare the inputs for speculative decoding
         It updates the common_attn_metadata for speculative decoding,
         but does not consider the rejected tokens. Instead, all tokens
         are included as inputs to the speculator, with the rejected tokens
         used as padding and filtered out later by `token_indices_to_sample`.
-        No blocking CPU operations should be introduced in this function.
         """
-        num_draft_tokens_gpu = torch.cat(
-            [
-                spec_decode_metadata.cu_num_draft_tokens[0:1],
-                spec_decode_metadata.cu_num_draft_tokens[1:]
-                - spec_decode_metadata.cu_num_draft_tokens[:-1],
-            ]
+        num_reqs = common_attn_metadata.num_reqs
+        device = valid_sampled_tokens_count.device
+
+        token_indices_to_sample = torch.empty(
+            (num_reqs,), dtype=torch.int32, device=device
         )
 
-        num_rejected_tokens_gpu = torch.where(
-            num_draft_tokens_gpu > 0,
-            num_draft_tokens_gpu + 1 - valid_sampled_tokens_count,
-            torch.zeros_like(num_draft_tokens_gpu),
+        # Kernel grid: one program per request (row)
+        grid = (num_reqs,)
+        eagle_prepare_inputs_padded_kernel[grid](
+            spec_decode_metadata.cu_num_draft_tokens,
+            valid_sampled_tokens_count,
+            common_attn_metadata.query_start_loc,
+            token_indices_to_sample,
+            num_reqs,
        )
 
         query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu
-
         new_query_len_per_req = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1]
 
         total_num_tokens = query_start_loc_cpu[-1].item()
-        token_indices = self.arange[:total_num_tokens]
 
         spec_common_attn_metadata = CommonAttentionMetadata(
             query_start_loc=common_attn_metadata.query_start_loc,
@@ -665,16 +660,12 @@ def prepare_inputs_padded(
             max_query_len=new_query_len_per_req.max().item(),
             max_seq_len=common_attn_metadata.seq_lens_cpu.max().item(),
             block_table_tensor=common_attn_metadata.block_table_tensor,
-            slot_mapping=common_attn_metadata.slot_mapping[token_indices],
+            slot_mapping=common_attn_metadata.slot_mapping[:total_num_tokens],
             causal=True,
             dcp_local_seq_lens=common_attn_metadata.dcp_local_seq_lens,
         )
 
-        token_indices_to_sample = (
-            common_attn_metadata.query_start_loc[1:] - 1 - num_rejected_tokens_gpu
-        )
-
-        return spec_common_attn_metadata, token_indices, token_indices_to_sample
+        return spec_common_attn_metadata, token_indices_to_sample
 
     def propose_tree(
         self,
vllm/v1/spec_decode/utils.py (105 additions, 0 deletions)

@@ -1,6 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 from vllm.sampling_params import SamplingParams
+from vllm.triton_utils import tl, triton
 
 _SAMPLING_EPS = 1e-5
 
@@ -14,3 +15,107 @@ def is_spec_decode_unsupported(sampling_params: SamplingParams) -> bool:
         or sampling_params.min_p > _SAMPLING_EPS
         or sampling_params.logprobs is not None
     )
+
+
+@triton.jit
+def eagle_prepare_inputs_padded_kernel(
+    cu_num_draft_tokens_ptr,  # [num_reqs]
+    valid_sampled_tokens_count_ptr,  # [num_reqs]
+    query_start_loc_gpu_ptr,  # [num_reqs + 1]
+    token_indices_to_sample_ptr,  # [num_reqs] (output)
+    num_reqs,  # tl.int32
+):
+    """
+    Fused kernel for Eagle prepare_inputs_padded. This kernel computes the
+    token index to sample for each request, taking into account the number
+    of draft tokens and the number of valid sampled tokens (which is one more than
+    the number of accepted tokens).
+    """
+    req_idx = tl.program_id(axis=0)
+    if req_idx >= num_reqs:
+        return
+
+    # Calculate num_draft_tokens from cu_num_draft_tokens, which is an inclusive
+    # cumulative sum (first entry is the first value, not zero).
+    cu_draft_curr = tl.load(cu_num_draft_tokens_ptr + req_idx)
+
+    num_draft_tokens = 0
+    if req_idx == 0:
+        num_draft_tokens = cu_draft_curr
+    else:
+        cu_draft_prev = tl.load(cu_num_draft_tokens_ptr + req_idx - 1)
+        num_draft_tokens = cu_draft_curr - cu_draft_prev
+
+    valid_count = tl.load(valid_sampled_tokens_count_ptr + req_idx)
+    num_rejected_tokens = num_draft_tokens + 1 - valid_count
+    num_rejected_tokens = tl.where(num_draft_tokens > 0, num_rejected_tokens, 0)
+
+    # query_start_loc[req_idx + 1] is the start position of the next request,
+    # which is one past the last token of this request.
+    q_last_tok_idx = tl.load(query_start_loc_gpu_ptr + req_idx + 1) - 1
+
+    index_to_sample = q_last_tok_idx - num_rejected_tokens
+    tl.store(token_indices_to_sample_ptr + req_idx, index_to_sample)
+
+
+@triton.jit
+def eagle_prepare_next_token_padded_kernel(
+    sampled_token_ids_ptr,  # [num_reqs, num_sampled_tokens_per_req]
+    discard_request_mask_ptr,  # [num_reqs]
+    backup_next_token_ids_ptr,  # [num_reqs]
+    next_token_ids_ptr,  # [num_reqs] (output)
+    valid_sampled_tokens_count_ptr,  # [num_reqs] (output)
+    vocab_size,  # tl.int32
+    num_sampled_tokens_per_req,  # tl.int32 (num_spec_tokens + 1)
+    num_reqs,  # tl.int32
+    stride_sampled_token_ids,  # tl.int32 (stride for dim 0)
+    BLOCK_SIZE_TOKENS: tl.constexpr,  # Power-of-2 >= num_sampled_tokens_per_req
+):
+    """
+    Fused kernel for Eagle prepare_next_token_ids_padded. This kernel computes the
+    number of valid (1 + accepted) tokens for each request, and the corresponding
+    "next" token id to sample from during speculative decoding. This is the
+    "last accepted token" from the sampled tokens, or the backup token if no
+    tokens were accepted or if the request is marked as discarded.
+    """
+    req_idx = tl.program_id(axis=0)
+    if req_idx >= num_reqs:
+        return
+
+    # Check if this request is discarded.
+    is_discarded = tl.load(discard_request_mask_ptr + req_idx)
+
+    if is_discarded:
+        backup_token = tl.load(backup_next_token_ids_ptr + req_idx)
+        valid_count = tl.full((), 0, dtype=tl.uint32)
+        tl.store(next_token_ids_ptr + req_idx, backup_token)
+        tl.store(valid_sampled_tokens_count_ptr + req_idx, valid_count)
+    else:
+        # Count the number of valid tokens among the sampled tokens.
+        token_offs = tl.arange(0, BLOCK_SIZE_TOKENS)
+        token_mask = token_offs < num_sampled_tokens_per_req
+
+        row_ptr = sampled_token_ids_ptr + req_idx * stride_sampled_token_ids
+        token_ids = tl.load(row_ptr + token_offs, mask=token_mask, other=-1)
+
+        # Rejected tokens are -1, valid tokens are in [0, vocab_size)
+        is_valid_mask = (token_ids != -1) & (token_ids < vocab_size) & token_mask
+        valid_count = tl.sum(is_valid_mask)
+
+        if valid_count > 0:
+            # Guaranteed to be well-defined since
+            # valid_count > 0 implies is_valid_mask is not empty
+            last_valid_index = tl.max(tl.where(is_valid_mask, token_offs, -1))
+
+            # Select the token at that index, using a sum trick since
+            # we don't want to load again to access token_ids[last_valid_index].
+            last_valid_token = tl.sum(
+                tl.where(token_offs == last_valid_index, token_ids, 0)
+            )
+            tl.store(next_token_ids_ptr + req_idx, last_valid_token)
+        else:
+            # No valid tokens found, use backup token
+            backup_token = tl.load(backup_next_token_ids_ptr + req_idx)
+            tl.store(next_token_ids_ptr + req_idx, backup_token)
+
+        tl.store(valid_sampled_tokens_count_ptr + req_idx, valid_count)
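To sanity-check eagle_prepare_next_token_padded_kernel, the same per-request result can be computed in plain PyTorch, essentially the tensor-level code this kernel replaces in eagle.py, adapted to the boolean mask. A rough reference sketch, not part of the commit (the helper name and the backup values for the first two requests are mine):

import torch

def prepare_next_token_reference(sampled_token_ids, discard_request_mask,
                                 backup_next_token_ids, vocab_size):
    """Pure-PyTorch mirror of the Triton kernel's per-request logic."""
    # Treat discarded rows as if every sampled token had been rejected.
    masked = sampled_token_ids.clone()
    masked[discard_request_mask] = -1

    # Valid tokens are non-negative and below the vocab size.
    valid_mask = (masked != -1) & (masked < vocab_size)
    valid_count = valid_mask.sum(dim=1).to(torch.int32)

    # Take the last valid token per row; fall back to the backup token
    # when a row has no valid tokens at all.
    last_valid = (valid_count - 1).clamp(min=0).long()
    selected = masked.gather(1, last_valid.unsqueeze(1)).squeeze(1)
    next_token_ids = torch.where(valid_count > 0, selected, backup_next_token_ids)
    return next_token_ids.to(torch.int32), valid_count

# Toy inputs patterned after the updated test above.
sampled = torch.tensor([[0, 1, -1, -1, -1],
                        [0, 1, 2, 3, 4],
                        [-1, -1, -1, -1, -1],
                        [0, 1, 2, -1, -1]], dtype=torch.int32)
discard = torch.tensor([False, False, False, True])
backup = torch.tensor([10, 20, 30, 40], dtype=torch.int32)
print(prepare_next_token_reference(sampled, discard, backup, vocab_size=100))
# next_token_ids -> [1, 4, 30, 40], valid_sampled_tokens_count -> [2, 5, 0, 0]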
