 from vllm.model_executor.models.llama_eagle3 import Eagle3LlamaForCausalLM
 from vllm.multimodal import MULTIMODAL_REGISTRY
 from vllm.platforms import current_platform
+from vllm.triton_utils import triton
 from vllm.utils.platform_utils import is_pin_memory_available
 from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata
 from vllm.v1.attention.backends.tree_attn import (
 from vllm.v1.sample.metadata import SamplingMetadata
 from vllm.v1.sample.sampler import _SAMPLING_EPS
 from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
+from vllm.v1.spec_decode.utils import (
+    eagle_prepare_inputs_padded_kernel,
+    eagle_prepare_next_token_padded_kernel,
+)
 from vllm.v1.utils import CpuGpuBuffer
 from vllm.v1.worker.dp_utils import coordinate_batch_across_dp
 from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
@@ -555,20 +560,15 @@ def prepare_next_token_ids_padded(
         sampled_token_ids: torch.Tensor,
         requests: dict[str, CachedRequestState],
         gpu_input_batch: InputBatch,
-        discard_request_indices: torch.Tensor,
-        num_discarded_requests: int,
+        discard_request_mask: torch.Tensor,
     ) -> tuple[torch.Tensor, torch.Tensor]:
         """
         This function is used to prepare the inputs for speculative decoding.
         It calculates the next token ids and the number of valid sampled tokens
         for each request, considering the "discarded" requests whose next token
-        is not sampled and comes from `request.get_token_id()` instead.
-        It also accounts for the rejected tokens in `sampled_token_ids`.
-        This function must use device functions to operate on the inputs, and
-        should not introduce any blocking CPU-GPU synchronization.
+        is not sampled and comes from `request.get_token_id()` instead; this is the
+        "backup" token id. It also accounts for rejected tokens in `sampled_token_ids`.
         """
-        # TODO(Ben): Combine this into a custom fused kernel
-
         # Precompute get_token_id for when there is no valid next token
         num_reqs = gpu_input_batch.num_reqs
         self.backup_next_token_ids.np[:num_reqs] = np.array(
@@ -577,44 +577,39 @@ def prepare_next_token_ids_padded(
                     common_attn_metadata.seq_lens_cpu[i].item()
                 )
                 for i in range(num_reqs)
-            ]
+            ],
+            dtype=np.int32,
         )
         self.backup_next_token_ids.copy_to_gpu(num_reqs)
+        backup_tokens_gpu = self.backup_next_token_ids.gpu

-        # Mask out the sampled tokens indices that should not be sampled.
-        discard_sampled_tokens_req_indices = discard_request_indices[
-            :num_discarded_requests
-        ]
+        batch_size, num_tokens = sampled_token_ids.shape
+        device = sampled_token_ids.device

-        valid_sampled_token_ids_gpu = sampled_token_ids.clone()
-        valid_sampled_token_ids_gpu.index_fill_(
-            0, discard_sampled_tokens_req_indices, -1
-        )
+        assert discard_request_mask.dtype == torch.bool
+        assert backup_tokens_gpu.dtype == torch.int32

-        # Generate a mask for all valid tokens within those requests
-        valid_mask = (valid_sampled_token_ids_gpu != -1) & (
-            valid_sampled_token_ids_gpu < gpu_input_batch.vocab_size
+        next_token_ids = torch.empty((batch_size,), dtype=torch.int32, device=device)
+        valid_sampled_tokens_count = torch.empty(
+            (batch_size,), dtype=torch.int32, device=device
         )

-        # Count the number of valid tokens in each request
-        valid_sampled_tokens_count = valid_mask.sum(dim=1)
+        # Kernel grid: one program per request (row)
+        grid = (batch_size,)

-        # Get the rightmost valid index per row
-        last_valid_indices = valid_sampled_tokens_count - 1
-        last_valid_indices_safe = torch.clamp(last_valid_indices, min=0)
-
-        # Get last valid token from each row
-        # (assume undefined state where there is no valid token)
-        selected_tokens = torch.gather(
-            valid_sampled_token_ids_gpu, 1, last_valid_indices_safe.unsqueeze(1)
-        ).squeeze(1)
-
-        # Use last token if valid, pre-computed backup if not
-        batch_size = valid_sampled_token_ids_gpu.shape[0]
-        next_token_ids = torch.where(
-            last_valid_indices != -1,
-            selected_tokens,
-            self.backup_next_token_ids.gpu[:batch_size],
+        # Find the next power of 2 for block sizes
+        BLOCK_SIZE_TOKENS = triton.next_power_of_2(num_tokens)
+        eagle_prepare_next_token_padded_kernel[grid](
+            sampled_token_ids,
+            discard_request_mask,
+            backup_tokens_gpu,
+            next_token_ids,
+            valid_sampled_tokens_count,
+            gpu_input_batch.vocab_size,
+            num_tokens,
+            batch_size,
+            sampled_token_ids.stride(0),
+            BLOCK_SIZE_TOKENS=BLOCK_SIZE_TOKENS,
         )

         return next_token_ids, valid_sampled_tokens_count
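
The two kernels imported from vllm/v1/spec_decode/utils.py are not shown in this diff. As a rough illustration of the per-request logic the launch above implies (one Triton program per row: count the valid sampled tokens, take the last one, and fall back to the precomputed backup token when the row is discarded or fully rejected), a kernel along the following lines would fit the call signature. This is a hedged sketch under assumed argument names, not the actual implementation in utils.py.

import triton
import triton.language as tl


@triton.jit
def eagle_prepare_next_token_padded_kernel(
    sampled_token_ids_ptr,           # [batch_size, num_tokens]
    discard_request_mask_ptr,        # [batch_size], bool
    backup_next_token_ids_ptr,       # [batch_size], int32
    next_token_ids_ptr,              # [batch_size], int32 (output)
    valid_sampled_tokens_count_ptr,  # [batch_size], int32 (output)
    vocab_size,
    num_tokens,
    batch_size,
    sampled_token_ids_stride,
    BLOCK_SIZE_TOKENS: tl.constexpr,
):
    # One program per request (row), as in grid = (batch_size,).
    req_idx = tl.program_id(0)
    discard = tl.load(discard_request_mask_ptr + req_idx)

    offsets = tl.arange(0, BLOCK_SIZE_TOKENS)
    in_bounds = offsets < num_tokens
    row_ptr = sampled_token_ids_ptr + req_idx * sampled_token_ids_stride
    tokens = tl.load(row_ptr + offsets, mask=in_bounds, other=-1)

    # A token is valid if the row is not discarded, the token was not
    # rejected (-1), and it lies inside the vocabulary.
    valid = in_bounds & (tokens != -1) & (tokens < vocab_size) & (discard == 0)
    num_valid = tl.sum(valid.to(tl.int32), axis=0)

    # Accepted tokens form a prefix of the row, so the last valid token sits
    # at index num_valid - 1 (clamped to 0 when nothing is valid).
    last_idx = tl.maximum(num_valid - 1, 0)
    selected = tl.sum(tl.where(offsets == last_idx, tokens, 0), axis=0)

    # Fall back to the precomputed backup token when no sampled token is valid.
    backup = tl.load(backup_next_token_ids_ptr + req_idx)
    next_token = tl.where(num_valid > 0, selected, backup)

    tl.store(next_token_ids_ptr + req_idx, next_token.to(tl.int32))
    tl.store(valid_sampled_tokens_count_ptr + req_idx, num_valid)

Fusing the masking, counting, gather, and backup fallback into one launch replaces the chain of small elementwise and reduction kernels (and their intermediate tensors) that the previous torch implementation issued.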
@@ -624,35 +619,35 @@ def prepare_inputs_padded(
         common_attn_metadata: CommonAttentionMetadata,
         spec_decode_metadata: SpecDecodeMetadata,
         valid_sampled_tokens_count: torch.Tensor,
-    ) -> tuple[CommonAttentionMetadata, torch.Tensor, torch.Tensor]:
+    ) -> tuple[CommonAttentionMetadata, torch.Tensor]:
         """
         This function is used to prepare the inputs for speculative decoding
         It updates the common_attn_metadata for speculative decoding,
         but does not consider the rejected tokens. Instead, all tokens
         are included as inputs to the speculator, with the rejected tokens
         used as padding and filtered out later by `token_indices_to_sample`.
-        No blocking CPU operations should be introduced in this function.
         """
-        num_draft_tokens_gpu = torch.cat(
-            [
-                spec_decode_metadata.cu_num_draft_tokens[0:1],
-                spec_decode_metadata.cu_num_draft_tokens[1:]
-                - spec_decode_metadata.cu_num_draft_tokens[:-1],
-            ]
+        num_reqs = common_attn_metadata.num_reqs
+        device = valid_sampled_tokens_count.device
+
+        token_indices_to_sample = torch.empty(
+            (num_reqs,), dtype=torch.int32, device=device
         )

-        num_rejected_tokens_gpu = torch.where(
-            num_draft_tokens_gpu > 0,
-            num_draft_tokens_gpu + 1 - valid_sampled_tokens_count,
-            torch.zeros_like(num_draft_tokens_gpu),
+        # Kernel grid: one program per request (row)
+        grid = (num_reqs,)
+        eagle_prepare_inputs_padded_kernel[grid](
+            spec_decode_metadata.cu_num_draft_tokens,
+            valid_sampled_tokens_count,
+            common_attn_metadata.query_start_loc,
+            token_indices_to_sample,
+            num_reqs,
         )

         query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu
-
         new_query_len_per_req = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1]

         total_num_tokens = query_start_loc_cpu[-1].item()
-        token_indices = self.arange[:total_num_tokens]

         spec_common_attn_metadata = CommonAttentionMetadata(
             query_start_loc=common_attn_metadata.query_start_loc,
@@ -665,16 +660,12 @@ def prepare_inputs_padded(
             max_query_len=new_query_len_per_req.max().item(),
             max_seq_len=common_attn_metadata.seq_lens_cpu.max().item(),
             block_table_tensor=common_attn_metadata.block_table_tensor,
-            slot_mapping=common_attn_metadata.slot_mapping[token_indices],
+            slot_mapping=common_attn_metadata.slot_mapping[:total_num_tokens],
             causal=True,
             dcp_local_seq_lens=common_attn_metadata.dcp_local_seq_lens,
         )

-        token_indices_to_sample = (
-            common_attn_metadata.query_start_loc[1:] - 1 - num_rejected_tokens_gpu
-        )
-
-        return spec_common_attn_metadata, token_indices, token_indices_to_sample
+        return spec_common_attn_metadata, token_indices_to_sample

     def propose_tree(
         self,
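
For reference, a similarly hedged sketch of eagle_prepare_inputs_padded_kernel, consistent with the launch in prepare_inputs_padded: one program per request that recovers the per-request draft count from the cumulative tensor, derives the number of rejected tokens, and writes the index of the last non-rejected token of each query span. Argument names and details are assumptions, not the kernel shipped in utils.py.

import triton
import triton.language as tl


@triton.jit
def eagle_prepare_inputs_padded_kernel(
    cu_num_draft_tokens_ptr,         # [num_reqs], cumulative draft-token counts
    valid_sampled_tokens_count_ptr,  # [num_reqs]
    query_start_loc_ptr,             # [num_reqs + 1]
    token_indices_to_sample_ptr,     # [num_reqs], int32 (output)
    num_reqs,
):
    # One program per request (row), as in grid = (num_reqs,).
    req_idx = tl.program_id(0)
    if req_idx >= num_reqs:
        return

    # Per-request draft count from the cumulative sums:
    # num_draft[i] = cu[i] - cu[i - 1], with cu[-1] treated as 0.
    cu_curr = tl.load(cu_num_draft_tokens_ptr + req_idx)
    cu_prev = tl.load(cu_num_draft_tokens_ptr + req_idx - 1, mask=req_idx > 0, other=0)
    num_draft = cu_curr - cu_prev

    # Rejected tokens: drafted + 1 bonus token minus the accepted count
    # (zero when nothing was drafted for this request).
    valid_count = tl.load(valid_sampled_tokens_count_ptr + req_idx)
    num_rejected = tl.where(num_draft > 0, num_draft + 1 - valid_count, 0)

    # Sample from the last non-rejected position of this request's query span.
    query_end = tl.load(query_start_loc_ptr + req_idx + 1)
    tl.store(token_indices_to_sample_ptr + req_idx, query_end - 1 - num_rejected)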