|
15 | 15 | from flax import nnx |
16 | 16 | from jax.experimental import mesh_utils |
17 | 17 | from jax.sharding import NamedSharding, PartitionSpec |
18 | | -from torchax.ops.mappings import j2t_dtype |
| 18 | +from torchax.ops.mappings import j2t, j2t_dtype |
19 | 19 | from vllm.config import VllmConfig |
20 | 20 | from vllm.distributed.kv_transfer import (get_kv_transfer_group, |
21 | 21 | has_kv_transfer_group) |
|
28 | 28 | from vllm.v1.kv_cache_interface import KVCacheConfig |
29 | 29 | from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, AsyncModelRunnerOutput, |
30 | 30 | DraftTokenIds, KVConnectorOutput, LogprobsLists, |
31 | | - ModelRunnerOutput) |
| 31 | + LogprobsTensors, ModelRunnerOutput) |
32 | 32 | from vllm.v1.request import Request |
33 | 33 | from vllm.v1.spec_decode.ngram_proposer import NgramProposer |
34 | 34 | from vllm.v1.worker.kv_connector_model_runner_mixin import \ |
@@ -122,7 +122,8 @@ def get_output(self) -> ModelRunnerOutput: |
122 | 122 | next_tokens_cpu = next_tokens_cpu[self.logits_indices_selector] |
123 | 123 | selected_token_ids = np.expand_dims(next_tokens_cpu[:self._num_reqs], |
124 | 124 | 1) |
125 | | - valid_sampled_token_ids = selected_token_ids.tolist() |
| 125 | + |
| 126 | + valid_sampled_token_ids = [token_id for token_id in selected_token_ids] |
126 | 127 | for i in self._discard_sampled_tokens_req_indices: |
127 | 128 | valid_sampled_token_ids[i].clear() |
128 | 129 | self._model_runner_output.sampled_token_ids = valid_sampled_token_ids |
@@ -190,7 +191,8 @@ def _substitute_placeholder_token( |
190 | 191 | return input_ids.at[token_in_tpu_cur_input_indices].set(update_values) |
191 | 192 |
|
192 | 193 |
|
193 | | -def _reorder_logits_indices(logprobs_lists, logits_indices_selector): |
| 194 | +def _reorder_logits_indices(logprobs_lists: LogprobsLists, |
| 195 | + logits_indices_selector: List[int]): |
194 | 196 | return LogprobsLists( |
195 | 197 | logprob_token_ids=[ |
196 | 198 | logprobs_lists.logprob_token_ids[i] |
@@ -595,7 +597,7 @@ def _modify_prev_results(self): |
595 | 597 | next_tokens_cpu = next_tokens_cpu[pre_logits_indices_selector] |
596 | 598 | selected_token_ids = np.expand_dims(next_tokens_cpu[:len(pre_req_ids)], |
597 | 599 | 1) |
598 | | - valid_sampled_token_ids = selected_token_ids.tolist() |
| 600 | + valid_sampled_token_ids = [token_id for token_id in selected_token_ids] |
599 | 601 |
|
600 | 602 | # Mask out the sampled tokens that should not be sampled. |
601 | 603 | for i in pre_discard_sampled_tokens_req_indices: |
@@ -898,7 +900,9 @@ def _sample_from_logits( |
898 | 900 | if logits_indices_selector is not None: |
899 | 901 | next_tokens = next_tokens[logits_indices_selector] |
900 | 902 | selected_token_ids = np.expand_dims(next_tokens[:num_reqs], 1) |
901 | | - valid_sampled_token_ids = selected_token_ids.tolist() |
| 903 | + valid_sampled_token_ids = [ |
| 904 | + token_id for token_id in selected_token_ids |
| 905 | + ] |
902 | 906 | else: |
903 | 907 | valid_sampled_token_ids = self.rejection_sampler.parse_output( |
904 | 908 | next_tokens, self.input_batch.vocab_size, |
@@ -975,10 +979,17 @@ def select_local_fn(local_array, local_indices): |
975 | 979 | return ret |
976 | 980 |
|
977 | 981 | @staticmethod |
978 | | - @functools.partial(jax.jit, static_argnames=("max_logprobs", )) |
979 | | - def _compute_and_gather_logprobs(logits, next_tokens, max_logprobs): |
980 | | - logprobs = compute_logprobs(logits) |
981 | | - return gather_logprobs(logprobs, next_tokens, max_logprobs) |
| 982 | + def _compute_and_gather_logprobs(logits, next_tokens, |
| 983 | + max_logprobs) -> LogprobsTensors: |
| 984 | + |
| 985 | + @functools.partial(jax.jit, static_argnames=("max_logprobs", )) |
| 986 | + def jit_compute_and_gather_logprobs(logits, next_tokens, max_logprobs): |
| 987 | + logprobs = compute_logprobs(logits) |
| 988 | + return gather_logprobs(logprobs, next_tokens, max_logprobs) |
| 989 | + |
| 990 | + logprobs = jit_compute_and_gather_logprobs(logits, next_tokens, |
| 991 | + max_logprobs) |
| 992 | + return jax.tree.map(lambda x: j2t(x.astype(jnp.float32)), logprobs) |
982 | 993 |
|
983 | 994 | def _prepare_dp_input_metadata(self, |
984 | 995 | scheduler_output: "VllmSchedulerOutput"): |
|
0 commit comments