2828from vllm .v1 .kv_cache_interface import KVCacheConfig
2929from vllm .v1 .outputs import (EMPTY_MODEL_RUNNER_OUTPUT , AsyncModelRunnerOutput ,
3030 DraftTokenIds , KVConnectorOutput , LogprobsLists ,
31- LogprobsTensors , ModelRunnerOutput )
31+ ModelRunnerOutput )
3232from vllm .v1 .request import Request
3333from vllm .v1 .spec_decode .ngram_proposer import NgramProposer
3434from vllm .v1 .worker .kv_connector_model_runner_mixin import \
@@ -122,10 +122,9 @@ def get_output(self) -> ModelRunnerOutput:
122122 next_tokens_cpu = next_tokens_cpu [self .logits_indices_selector ]
123123 selected_token_ids = np .expand_dims (next_tokens_cpu [:self ._num_reqs ],
124124 1 )
125-
126- valid_sampled_token_ids = [token_id for token_id in selected_token_ids ]
125+ valid_sampled_token_ids = selected_token_ids .tolist ()
127126 for i in self ._discard_sampled_tokens_req_indices :
128- valid_sampled_token_ids [i ] = np . array ([] )
127+ valid_sampled_token_ids [i ]. clear ( )
129128 self ._model_runner_output .sampled_token_ids = valid_sampled_token_ids
130129 return self ._model_runner_output
131130
@@ -614,11 +613,11 @@ def _modify_prev_results(self):
614613 next_tokens_cpu = next_tokens_cpu [pre_logits_indices_selector ]
615614 selected_token_ids = np .expand_dims (next_tokens_cpu [:len (pre_req_ids )],
616615 1 )
617- valid_sampled_token_ids = [ token_id for token_id in selected_token_ids ]
616+ valid_sampled_token_ids = selected_token_ids . tolist ()
618617
619618 # Mask out the sampled tokens that should not be sampled.
620619 for i in pre_discard_sampled_tokens_req_indices :
621- valid_sampled_token_ids [i ] = np . array ([] )
620+ valid_sampled_token_ids [i ]. clear ( )
622621 # Append sampled tokens
623622 for pre_req_idx , req_state , _ in pre_request_seq_lens :
624623 sampled_ids = valid_sampled_token_ids [pre_req_idx ]
@@ -940,9 +939,7 @@ def _sample_from_logits(
940939 if logits_indices_selector is not None :
941940 next_tokens = next_tokens [logits_indices_selector ]
942941 selected_token_ids = np .expand_dims (next_tokens [:num_reqs ], 1 )
943- valid_sampled_token_ids = [
944- token_id for token_id in selected_token_ids
945- ]
942+ valid_sampled_token_ids = selected_token_ids .tolist ()
946943 else :
947944 valid_sampled_token_ids = self .rejection_sampler .parse_output (
948945 next_tokens , self .input_batch .vocab_size ,
@@ -951,7 +948,7 @@ def _sample_from_logits(
951948
952949 # Mask out the sampled tokens that should not be sampled.
953950 for i in discard_sampled_tokens_req_indices :
954- valid_sampled_token_ids [i ] = np . array ([] )
951+ valid_sampled_token_ids [i ]. clear ( )
955952 # Append sampled tokens
956953 for req_idx , req_state , _ in request_seq_lens :
957954 sampled_ids = valid_sampled_token_ids [req_idx ]
@@ -1018,8 +1015,7 @@ def select_local_fn(local_array, local_indices):
10181015
10191016 @staticmethod
10201017 @functools .partial (jax .jit , static_argnames = ("max_logprobs" , ))
1021- def _compute_and_gather_logprobs (logits , next_tokens ,
1022- max_logprobs ) -> LogprobsTensors :
1018+ def _compute_and_gather_logprobs (logits , next_tokens , max_logprobs ):
10231019 logprobs = compute_logprobs (logits )
10241020 return gather_logprobs (logprobs , next_tokens , max_logprobs )
10251021
0 commit comments