Skip to content

Commit 51c8f60

Browse files
authored
[Bugfix] Resolve MTP > 1 issue when lm head tp > 1 (#4254)
### What this PR does / why we need it? Previously, the dummy run executed compute_logits only once, regardless of num_speculative_tokens. This caused execute_model to hang on compute_logits when lm head tensor parallelism exceeded 1. The fix ensures compute_logits executes correctly during dummy run, matching num_speculative_tokens. I set the `non_blocking` argument to False when moving `exceeds_max_model_len` to the CPU. From what I understand, using `non_blocking=True` and immediately accessing the tensor on the CPU can cause accuracy problems. However, this issue doesn't happen when transferring data to a device. ref: https://discuss.pytorch.org/t/should-we-set-non-blocking-to-true/38234/18 - vLLM version: v0.11.0 - vLLM main: vllm-project/vllm@2918c1b --------- Signed-off-by: Jade Zheng <zheng.shoujian@outlook.com>
1 parent e8e20c0 commit 51c8f60

File tree

5 files changed

+31
-19
lines changed

5 files changed

+31
-19
lines changed

vllm_ascend/spec_decode/eagle_proposer.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -123,7 +123,8 @@ def dummy_run(self,
123123
num_reqs: int = 0,
124124
num_tokens_across_dp: Optional[torch.Tensor] = None,
125125
aclgraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE,
126-
batch_descriptor=None):
126+
batch_descriptor=None,
127+
dummy_compute_logits=lambda hidden_states: None):
127128
moe_comm_type = self.runner._select_moe_comm_method(num_tokens)
128129
with set_ascend_forward_context(None,
129130
self.vllm_config,
@@ -134,6 +135,7 @@ def dummy_run(self,
134135
positions=self.positions[:num_tokens],
135136
hidden_states=self.hidden_states[:num_tokens],
136137
)
138+
dummy_compute_logits(self.hidden_states)
137139

138140
def generate_token_ids(self,
139141
valid_sampled_token_ids: list[np.ndarray],

vllm_ascend/spec_decode/mtp_proposer.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -213,7 +213,8 @@ def dummy_run(self,
213213
num_reqs: int = 0,
214214
num_tokens_across_dp=None,
215215
aclgraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE,
216-
batch_descriptor=None) -> None:
216+
batch_descriptor=None,
217+
dummy_compute_logits=lambda hidden_states: None) -> None:
217218

218219
(
219220
num_tokens,
@@ -296,6 +297,7 @@ def dummy_run(self,
296297
self.update_stream, forward_context,
297298
positions.shape[0],
298299
self.vllm_config.speculative_config)
300+
dummy_compute_logits(previous_hidden_states)
299301
if with_prefill:
300302
break
301303

@@ -756,6 +758,7 @@ def _propose(
756758
logits = self.model.compute_logits(sample_hidden_states)
757759
if lmhead_tp_enable() and num_indices < logits.shape[0]:
758760
logits = logits[:num_indices]
761+
last_token_indices = last_token_indices[:num_indices]
759762
draft_token_ids = logits.argmax(dim=-1)
760763

761764
if self.num_speculative_tokens == 1:
@@ -821,7 +824,7 @@ def _propose(
821824
# For the requests that exceed the max model length, we set the
822825
# sequence length to 1 to minimize their overheads in attention.
823826
exceeds_max_model_len_cpu = exceeds_max_model_len.to(
824-
attn_metadata_i.seq_lens.device, non_blocking=True)
827+
attn_metadata_i.seq_lens.device, non_blocking=False)
825828
attn_metadata_i.seq_lens[:batch_size].masked_fill_(
826829
exceeds_max_model_len_cpu, 1)
827830
# Mask out the slot mappings that exceed the max model length.

vllm_ascend/spec_decode/ngram_proposer.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,8 @@ def dummy_run(self,
2727
num_reqs=None,
2828
num_tokens_across_dp=None,
2929
aclgraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE,
30-
batch_descriptor=None):
30+
batch_descriptor=None,
31+
dummy_compute_logits=lambda hidden_states: None):
3132
pass
3233

3334
def generate_token_ids(self,

vllm_ascend/torchair/torchair_mtp_proposer.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,8 @@ def dummy_run(self,
8181
num_reqs: int = 0,
8282
num_tokens_across_dp=None,
8383
aclgraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE,
84-
batch_descriptor=None) -> None:
84+
batch_descriptor=None,
85+
dummy_compute_logits=lambda hidden_states: None) -> None:
8586
moe_comm_type = self.runner._select_moe_comm_method(num_tokens)
8687

8788
if not with_prefill:
@@ -143,6 +144,7 @@ def dummy_run(self,
143144
self.model(input_ids=input_ids,
144145
positions=positions,
145146
hidden_states=previous_hidden_states)
147+
dummy_compute_logits(previous_hidden_states)
146148
if with_prefill:
147149
break
148150

vllm_ascend/worker/model_runner_v1.py

Lines changed: 18 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -3003,14 +3003,21 @@ def _dummy_run(
30033003

30043004
need_dummy_logits = (not self.in_profile_run
30053005
and lmhead_tp_enable())
3006-
3007-
if need_dummy_logits:
3008-
max_num_reqs_across_dp = num_tokens if not with_prefill else max_num_reqs
3009-
dummy_indices = torch.zeros(max_num_reqs_across_dp,
3010-
dtype=torch.int32)
3011-
3012-
def dummy_compute_logits(hidden_states):
3013-
return self.model.compute_logits(
3006+
max_num_reqs_across_dp = num_tokens if not with_prefill else max_num_reqs
3007+
dummy_indices = torch.zeros(max_num_reqs_across_dp,
3008+
dtype=torch.int32)
3009+
3010+
def dummy_compute_logits(hidden_states):
3011+
if not need_dummy_logits:
3012+
return None
3013+
return self.model.compute_logits(hidden_states[dummy_indices])
3014+
3015+
def dummy_drafter_compute_logits(hidden_states):
3016+
if not need_dummy_logits or self.drafter is None:
3017+
return
3018+
if hasattr(self.drafter, "model") and hasattr(
3019+
self.drafter.model, "compute_logits"):
3020+
return self.drafter.model.compute_logits(
30143021
hidden_states[dummy_indices])
30153022

30163023
with set_ascend_forward_context(
@@ -3032,8 +3039,7 @@ def dummy_compute_logits(hidden_states):
30323039
with_prefill, is_torchair_compile, input_ids, positions,
30333040
attn_metadata, num_tokens, intermediate_tensors,
30343041
inputs_embeds)
3035-
if need_dummy_logits:
3036-
dummy_compute_logits(hidden_states)
3042+
dummy_compute_logits(hidden_states)
30373043

30383044
if self.drafter:
30393045
self.drafter.dummy_run(
@@ -3042,10 +3048,8 @@ def dummy_compute_logits(hidden_states):
30423048
num_reqs=num_reqs,
30433049
num_tokens_across_dp=num_tokens_across_dp,
30443050
aclgraph_runtime_mode=aclgraph_runtime_mode,
3045-
batch_descriptor=batch_descriptor)
3046-
if need_dummy_logits:
3047-
self.drafter.model.compute_logits(
3048-
hidden_states[dummy_indices])
3051+
batch_descriptor=batch_descriptor,
3052+
dummy_compute_logits=dummy_drafter_compute_logits)
30493053
if self.in_profile_run and self.dynamic_eplb:
30503054
self.model.clear_all_moe_loads()
30513055
if not self.in_profile_run and self.dynamic_eplb:

0 commit comments

Comments
 (0)