
Commit 2b4f7a5

[cherry-pick pr-4254] bugfix for mtp>1 when lm_head_tp>1 (#4360)

### What this PR does / why we need it?

Previously, the dummy run executed compute_logits only once, regardless of num_speculative_tokens. With lm-head tensor parallelism greater than 1, compute_logits involves communication across the lm-head TP group, so execute_model would hang on compute_logits when the real run issued more calls than the dummy run. The fix makes the dummy run execute compute_logits the correct number of times, matching num_speculative_tokens.

Signed-off-by: zouyida2052 <zouyida2002@gmail.com>

1 parent cd9f5c0 · commit 2b4f7a5
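
The failure mode is a collective-call-count mismatch: under lm-head TP, every rank in the group must enter compute_logits the same number of times, and a rank whose dummy run enters only once leaves peers that entered num_speculative_tokens times blocked. A minimal stand-alone sketch of the effect (a threading.Barrier stands in for the TP collective; this is illustrative, not vllm-ascend code):

import threading

NUM_RANKS = 2
# Stand-in for the collective inside compute_logits: every rank in the
# lm-head TP group must arrive, or the others block.
collective = threading.Barrier(NUM_RANKS, timeout=2)

def compute_logits_collective(rank: int) -> None:
    collective.wait()

def run_rank(rank: int, num_logits_calls: int) -> None:
    try:
        for _ in range(num_logits_calls):
            compute_logits_collective(rank)
        print(f"rank {rank}: finished")
    except threading.BrokenBarrierError:
        print(f"rank {rank}: hung waiting for peers (call-count mismatch)")

# Before the fix: the dummy run on rank 1 enters the collective once,
# while rank 0's real run enters num_speculative_tokens (= 3) times.
threads = [
    threading.Thread(target=run_rank, args=(0, 3)),
    threading.Thread(target=run_rank, args=(1, 1)),
]
for t in threads:
    t.start()
for t in threads:
    t.join()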

3 files changed: +26 −16 lines

vllm_ascend/spec_decode/eagle_proposer.py

Lines changed: 3 additions & 1 deletion

@@ -116,7 +116,8 @@ def dummy_run(self,
                   num_reqs: int = 0,
                   num_tokens_across_dp: Optional[torch.Tensor] = None,
                   aclgraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE,
-                  batch_descriptor=None):
+                  batch_descriptor=None,
+                  dummy_compute_logits=lambda hidden_states: None):
         moe_comm_type = self.runner._select_moe_comm_method(
             num_tokens, with_prefill)
         with set_ascend_forward_context(None,
@@ -128,6 +129,7 @@ def dummy_run(self,
                 positions=self.positions[:num_tokens],
                 hidden_states=self.hidden_states[:num_tokens],
             )
+            dummy_compute_logits(self.hidden_states)

     def generate_token_ids(self,
                            valid_sampled_token_ids: list[list[int]],
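
The proposer-side change is the whole of the new contract: dummy_run grows a dummy_compute_logits callback that defaults to a no-op lambda, so existing callers are untouched while the model runner can opt in to lm-head warmup. A simplified sketch of the shape (free function with dummy tensors and hypothetical names, not the real proposer):

from typing import Callable, Optional

import torch

def dummy_run(hidden_states: torch.Tensor,
              dummy_compute_logits: Callable[[torch.Tensor], Optional[torch.Tensor]]
              = lambda hidden_states: None) -> None:
    # ... the draft-model forward would fill hidden_states here ...
    # Invoke the callback so lm-head TP ranks issue the same collective
    # calls in a dummy run as in a real step; the default does nothing.
    dummy_compute_logits(hidden_states)

hidden = torch.zeros(4, 8)
dummy_run(hidden)                                      # default: no warmup
dummy_run(hidden, dummy_compute_logits=lambda h: h.sum())  # opt-in callback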

vllm_ascend/spec_decode/mtp_proposer.py

Lines changed: 5 additions & 2 deletions

@@ -114,7 +114,8 @@ def dummy_run(self,
                   num_reqs: int = 0,
                   num_tokens_across_dp=None,
                   aclgraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE,
-                  batch_descriptor=None) -> None:
+                  batch_descriptor=None,
+                  dummy_compute_logits=lambda hidden_states: None) -> None:
         if not self.torchair_graph_enabled:
             # TODO: adapt enable_dbo later
             (num_tokens, num_tokens_across_dp, with_prefill,
@@ -188,6 +189,7 @@ def dummy_run(self,
             self.model(input_ids=input_ids,
                        positions=positions,
                        hidden_states=previous_hidden_states)
+            dummy_compute_logits(previous_hidden_states)
             if with_prefill:
                 break
@@ -490,6 +492,7 @@ def _propose(
         logits = self.model.compute_logits(sample_hidden_states)
         if lmhead_tp_enable() and num_indices < logits.shape[0]:
             logits = logits[:num_indices]
+            last_token_indices = last_token_indices[:num_indices]
         draft_token_ids = logits.argmax(dim=-1)

         if self.num_speculative_tokens == 1:
@@ -554,7 +557,7 @@ def _propose(
             # For the requests that exceed the max model length, we set the
             # sequence length to 1 to minimize their overheads in attention.
             exceeds_max_model_len_cpu = exceeds_max_model_len.to(
-                attn_metadata_i.seq_lens.device, non_blocking=True)
+                attn_metadata_i.seq_lens.device, non_blocking=False)
             attn_metadata_i.seq_lens[:batch_size].masked_fill_(
                 exceeds_max_model_len_cpu, 1)
             # Mask out the slot mappings that exceed the max model length.
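
Two smaller fixes ride along in this file: last_token_indices is sliced to stay aligned with the logits truncated under lm-head TP, and the device-to-host copy of exceeds_max_model_len switches to non_blocking=False. The second matters because an asynchronous copy toward host memory can return before the data lands, so the masked_fill_ that reads the mask may observe stale values. A hedged sketch of the safe pattern (assuming seq_lens lives on the CPU here; hypothetical helper name):

import torch

def clamp_seq_lens(exceeds_max_model_len: torch.Tensor,
                   seq_lens: torch.Tensor, batch_size: int) -> None:
    # non_blocking=False blocks until the tensor has fully landed in host
    # memory before masked_fill_ reads it; an async device->host copy
    # could otherwise still be in flight at that point.
    exceeds_cpu = exceeds_max_model_len.to(seq_lens.device,
                                           non_blocking=False)
    seq_lens[:batch_size].masked_fill_(exceeds_cpu, 1)

seq_lens = torch.tensor([10, 20, 30, 40])
exceeds = torch.tensor([False, True, False, True])
clamp_seq_lens(exceeds, seq_lens, batch_size=4)
print(seq_lens)  # tensor([10,  1, 30,  1])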

vllm_ascend/worker/model_runner_v1.py

Lines changed: 18 additions & 13 deletions

@@ -2465,13 +2465,21 @@ def _dummy_run(
         need_dummy_logits = (not self.in_profile_run
                              and lmhead_tp_enable())

-        if need_dummy_logits:
-            max_num_reqs_across_dp = num_tokens if not with_prefill else max_num_reqs
-            dummy_indices = torch.zeros(max_num_reqs_across_dp,
-                                        dtype=torch.int32)
-
-            def dummy_compute_logits(hidden_states):
-                return self.model.compute_logits(
+        max_num_reqs_across_dp = num_tokens if not with_prefill else max_num_reqs
+        dummy_indices = torch.zeros(max_num_reqs_across_dp,
+                                    dtype=torch.int32)
+
+        def dummy_compute_logits(hidden_states):
+            if not need_dummy_logits:
+                return None
+            return self.model.compute_logits(hidden_states[dummy_indices])
+
+        def dummy_drafter_compute_logits(hidden_states):
+            if not need_dummy_logits or self.drafter is None:
+                return
+            if hasattr(self.drafter, "model") and hasattr(
+                    self.drafter.model, "compute_logits"):
+                return self.drafter.model.compute_logits(
                     hidden_states[dummy_indices])

         with set_ascend_forward_context(
@@ -2493,8 +2501,7 @@ def dummy_compute_logits(hidden_states):
                 with_prefill, is_torchair_compile, input_ids, positions,
                 attn_metadata, num_tokens, intermediate_tensors,
                 inputs_embeds)
-            if need_dummy_logits:
-                dummy_compute_logits(hidden_states)
+            dummy_compute_logits(hidden_states)

             if self.drafter:
                 self.drafter.dummy_run(
@@ -2504,10 +2511,8 @@ def dummy_compute_logits(hidden_states):
                     num_reqs=num_reqs,
                     num_tokens_across_dp=num_tokens_across_dp,
                     aclgraph_runtime_mode=aclgraph_runtime_mode,
-                    batch_descriptor=batch_descriptor)
-            if need_dummy_logits:
-                self.drafter.model.compute_logits(
-                    hidden_states[dummy_indices])
+                    batch_descriptor=batch_descriptor,
+                    dummy_compute_logits=dummy_drafter_compute_logits)
             if self.in_profile_run and self.dynamic_eplb:
                 self.model.clear_all_moe_loads()
             if not self.in_profile_run and self.dynamic_eplb:
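
The runner now defines the closures unconditionally and moves the need_dummy_logits guard inside them, so every call site (the main model and the drafter) invokes the callback without repeating the guard, and the drafter's warmup follows its own forward loop instead of a single post-hoc call. A hedged sketch of the wiring (stand-in Runner/Drafter classes; only the guard-inside-closure pattern mirrors the diff):

from typing import Optional

import torch

class Drafter:
    def __init__(self, model):
        self.model = model

    def dummy_run(self, dummy_compute_logits=lambda hidden_states: None):
        hidden_states = torch.zeros(4, 8)
        # The proposer invokes the callback inside its dummy forward loop,
        # so call counts match num_speculative_tokens on active ranks.
        dummy_compute_logits(hidden_states)

class Runner:
    def __init__(self, model, drafter: Optional[Drafter],
                 need_dummy_logits: bool):
        self.model = model
        self.drafter = drafter
        self.need_dummy_logits = need_dummy_logits
        self.dummy_indices = torch.zeros(4, dtype=torch.int32)

    def _dummy_run(self, hidden_states: torch.Tensor) -> None:
        # Closures are always defined; the guard lives inside them, so
        # every call site stays unconditional.
        def dummy_compute_logits(hs):
            if not self.need_dummy_logits:
                return None
            return self.model.compute_logits(hs[self.dummy_indices])

        def dummy_drafter_compute_logits(hs):
            if not self.need_dummy_logits or self.drafter is None:
                return None
            if hasattr(self.drafter, "model") and hasattr(
                    self.drafter.model, "compute_logits"):
                return self.drafter.model.compute_logits(
                    hs[self.dummy_indices])

        dummy_compute_logits(hidden_states)
        if self.drafter:
            self.drafter.dummy_run(
                dummy_compute_logits=dummy_drafter_compute_logits)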
