Skip to content

Commit ba8d6a2

Browse files
committed
Revert "[Bugfix] Fix error where vLLM expects numpy sampled token ids (#1119)"
This reverts commit 45edde6. Signed-off-by: Lihao Ran <imlihao.ran@gmail.com>
1 parent 1315868 commit ba8d6a2

File tree

1 file changed

+8
-12
lines changed

1 file changed

+8
-12
lines changed

tpu_inference/runner/tpu_runner.py

Lines changed: 8 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@
2828
from vllm.v1.kv_cache_interface import KVCacheConfig
2929
from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, AsyncModelRunnerOutput,
3030
DraftTokenIds, KVConnectorOutput, LogprobsLists,
31-
LogprobsTensors, ModelRunnerOutput)
31+
ModelRunnerOutput)
3232
from vllm.v1.request import Request
3333
from vllm.v1.spec_decode.ngram_proposer import NgramProposer
3434
from vllm.v1.worker.kv_connector_model_runner_mixin import \
@@ -122,10 +122,9 @@ def get_output(self) -> ModelRunnerOutput:
122122
next_tokens_cpu = next_tokens_cpu[self.logits_indices_selector]
123123
selected_token_ids = np.expand_dims(next_tokens_cpu[:self._num_reqs],
124124
1)
125-
126-
valid_sampled_token_ids = [token_id for token_id in selected_token_ids]
125+
valid_sampled_token_ids = selected_token_ids.tolist()
127126
for i in self._discard_sampled_tokens_req_indices:
128-
valid_sampled_token_ids[i] = np.array([])
127+
valid_sampled_token_ids[i].clear()
129128
self._model_runner_output.sampled_token_ids = valid_sampled_token_ids
130129
return self._model_runner_output
131130

@@ -614,11 +613,11 @@ def _modify_prev_results(self):
614613
next_tokens_cpu = next_tokens_cpu[pre_logits_indices_selector]
615614
selected_token_ids = np.expand_dims(next_tokens_cpu[:len(pre_req_ids)],
616615
1)
617-
valid_sampled_token_ids = [token_id for token_id in selected_token_ids]
616+
valid_sampled_token_ids = selected_token_ids.tolist()
618617

619618
# Mask out the sampled tokens that should not be sampled.
620619
for i in pre_discard_sampled_tokens_req_indices:
621-
valid_sampled_token_ids[i] = np.array([])
620+
valid_sampled_token_ids[i].clear()
622621
# Append sampled tokens
623622
for pre_req_idx, req_state, _ in pre_request_seq_lens:
624623
sampled_ids = valid_sampled_token_ids[pre_req_idx]
@@ -940,9 +939,7 @@ def _sample_from_logits(
940939
if logits_indices_selector is not None:
941940
next_tokens = next_tokens[logits_indices_selector]
942941
selected_token_ids = np.expand_dims(next_tokens[:num_reqs], 1)
943-
valid_sampled_token_ids = [
944-
token_id for token_id in selected_token_ids
945-
]
942+
valid_sampled_token_ids = selected_token_ids.tolist()
946943
else:
947944
valid_sampled_token_ids = self.rejection_sampler.parse_output(
948945
next_tokens, self.input_batch.vocab_size,
@@ -951,7 +948,7 @@ def _sample_from_logits(
951948

952949
# Mask out the sampled tokens that should not be sampled.
953950
for i in discard_sampled_tokens_req_indices:
954-
valid_sampled_token_ids[i] = np.array([])
951+
valid_sampled_token_ids[i].clear()
955952
# Append sampled tokens
956953
for req_idx, req_state, _ in request_seq_lens:
957954
sampled_ids = valid_sampled_token_ids[req_idx]
@@ -1018,8 +1015,7 @@ def select_local_fn(local_array, local_indices):
10181015

10191016
@staticmethod
10201017
@functools.partial(jax.jit, static_argnames=("max_logprobs", ))
1021-
def _compute_and_gather_logprobs(logits, next_tokens,
1022-
max_logprobs) -> LogprobsTensors:
1018+
def _compute_and_gather_logprobs(logits, next_tokens, max_logprobs):
10231019
logprobs = compute_logprobs(logits)
10241020
return gather_logprobs(logprobs, next_tokens, max_logprobs)
10251021

0 commit comments

Comments (0)