Skip to content

Commit d4d2909

Browse files
committed
[Bugfix] Fix error where vLLM expects numpy sampled token ids
Signed-off-by: Kyuyeun Kim <kyuyeunk@google.com>
1 parent fdeb5de commit d4d2909

File tree

1 file changed

+21
-10
lines changed

1 file changed

+21
-10
lines changed

tpu_inference/runner/tpu_runner.py

Lines changed: 21 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
from flax import nnx
1616
from jax.experimental import mesh_utils
1717
from jax.sharding import NamedSharding, PartitionSpec
18-
from torchax.ops.mappings import j2t_dtype
18+
from torchax.ops.mappings import j2t, j2t_dtype
1919
from vllm.config import VllmConfig
2020
from vllm.distributed.kv_transfer import (get_kv_transfer_group,
2121
has_kv_transfer_group)
@@ -28,7 +28,7 @@
2828
from vllm.v1.kv_cache_interface import KVCacheConfig
2929
from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, AsyncModelRunnerOutput,
3030
DraftTokenIds, KVConnectorOutput, LogprobsLists,
31-
ModelRunnerOutput)
31+
LogprobsTensors, ModelRunnerOutput)
3232
from vllm.v1.request import Request
3333
from vllm.v1.spec_decode.ngram_proposer import NgramProposer
3434
from vllm.v1.worker.kv_connector_model_runner_mixin import \
@@ -122,7 +122,8 @@ def get_output(self) -> ModelRunnerOutput:
122122
next_tokens_cpu = next_tokens_cpu[self.logits_indices_selector]
123123
selected_token_ids = np.expand_dims(next_tokens_cpu[:self._num_reqs],
124124
1)
125-
valid_sampled_token_ids = selected_token_ids.tolist()
125+
126+
valid_sampled_token_ids = [token_id for token_id in selected_token_ids]
126127
for i in self._discard_sampled_tokens_req_indices:
127128
valid_sampled_token_ids[i].clear()
128129
self._model_runner_output.sampled_token_ids = valid_sampled_token_ids
@@ -190,7 +191,8 @@ def _substitute_placeholder_token(
190191
return input_ids.at[token_in_tpu_cur_input_indices].set(update_values)
191192

192193

193-
def _reorder_logits_indices(logprobs_lists, logits_indices_selector):
194+
def _reorder_logits_indices(logprobs_lists: LogprobsLists,
195+
logits_indices_selector: List[int]):
194196
return LogprobsLists(
195197
logprob_token_ids=[
196198
logprobs_lists.logprob_token_ids[i]
@@ -595,7 +597,7 @@ def _modify_prev_results(self):
595597
next_tokens_cpu = next_tokens_cpu[pre_logits_indices_selector]
596598
selected_token_ids = np.expand_dims(next_tokens_cpu[:len(pre_req_ids)],
597599
1)
598-
valid_sampled_token_ids = selected_token_ids.tolist()
600+
valid_sampled_token_ids = [token_id for token_id in selected_token_ids]
599601

600602
# Mask out the sampled tokens that should not be sampled.
601603
for i in pre_discard_sampled_tokens_req_indices:
@@ -898,7 +900,9 @@ def _sample_from_logits(
898900
if logits_indices_selector is not None:
899901
next_tokens = next_tokens[logits_indices_selector]
900902
selected_token_ids = np.expand_dims(next_tokens[:num_reqs], 1)
901-
valid_sampled_token_ids = selected_token_ids.tolist()
903+
valid_sampled_token_ids = [
904+
token_id for token_id in selected_token_ids
905+
]
902906
else:
903907
valid_sampled_token_ids = self.rejection_sampler.parse_output(
904908
next_tokens, self.input_batch.vocab_size,
@@ -975,10 +979,17 @@ def select_local_fn(local_array, local_indices):
975979
return ret
976980

977981
@staticmethod
978-
@functools.partial(jax.jit, static_argnames=("max_logprobs", ))
979-
def _compute_and_gather_logprobs(logits, next_tokens, max_logprobs):
980-
logprobs = compute_logprobs(logits)
981-
return gather_logprobs(logprobs, next_tokens, max_logprobs)
982+
def _compute_and_gather_logprobs(logits, next_tokens,
983+
max_logprobs) -> LogprobsTensors:
984+
985+
@functools.partial(jax.jit, static_argnames=("max_logprobs", ))
986+
def jit_compute_and_gather_logprobs(logits, next_tokens, max_logprobs):
987+
logprobs = compute_logprobs(logits)
988+
return gather_logprobs(logprobs, next_tokens, max_logprobs)
989+
990+
logprobs = jit_compute_and_gather_logprobs(logits, next_tokens,
991+
max_logprobs)
992+
return jax.tree.map(lambda x: j2t(x.astype(jnp.float32)), logprobs)
982993

983994
def _prepare_dp_input_metadata(self,
984995
scheduler_output: "VllmSchedulerOutput"):

0 commit comments

Comments
 (0)