Skip to content

Commit 953f9f1

Browse files
committed
Fix MTP aclgraph dispatch error
Signed-off-by: wangxiyuan <wangxiyuan1007@gmail.com>
1 parent 834e2f2 commit 953f9f1

File tree

2 files changed

+5
-13
lines changed

2 files changed

+5
-13
lines changed

vllm_ascend/spec_decode/mtp_proposer.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
import torch.nn.functional as F
88
from vllm.config import (CUDAGraphMode, VllmConfig,
99
get_layers_from_vllm_config, set_current_vllm_config)
10-
from vllm.forward_context import BatchDescriptor, get_forward_context
10+
from vllm.forward_context import get_forward_context
1111
from vllm.logger import init_logger
1212
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
1313
from vllm.model_executor.model_loader import get_model_loader
@@ -693,13 +693,11 @@ def _propose(
693693
2))) and (scheduler_output.total_num_scheduled_tokens
694694
== self.runner.input_batch.num_reqs *
695695
(self.num_speculative_tokens + 1))
696-
batch_descriptor = BatchDescriptor(num_tokens=num_input_tokens,
697-
uniform_decode=uniform_decode)
698696
else:
699-
batch_descriptor = BatchDescriptor(num_tokens=num_input_tokens,
700-
uniform_decode=False)
697+
uniform_decode = False
698+
has_lora = len(self.runner.input_batch.lora_id_to_lora_request) > 0
701699
aclgraph_runtime_mode, batch_descriptor = \
702-
self.runner.aclgraph_dispatcher.dispatch(batch_descriptor)
700+
self.runner.aclgraph_dispatcher.dispatch(num_tokens=num_input_tokens, uniform_decode=uniform_decode, has_lora=has_lora)
703701

704702
if self.vllm_config.compilation_config.cudagraph_mode.has_full_cudagraphs(
705703
) and aclgraph_runtime_mode == CUDAGraphMode.FULL:

vllm_ascend/torchair/torchair_mtp_proposer.py

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from torchair import patch_for_hcom
77
from vllm.config import (CUDAGraphMode, VllmConfig,
88
get_layers_from_vllm_config, set_current_vllm_config)
9-
from vllm.forward_context import BatchDescriptor, get_forward_context
9+
from vllm.forward_context import get_forward_context
1010
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
1111
from vllm.model_executor.model_loader import get_model_loader
1212
from vllm.model_executor.model_loader.utils import \
@@ -343,12 +343,7 @@ def _propose_torchair(
343343
# torchair mode can reuse self.runner.num_tokens_across_dp
344344
num_tokens_across_dp = self.runner.num_tokens_across_dp
345345
with_prefill = self.runner.with_prefill
346-
347346
moe_comm_type = self.runner._select_moe_comm_method(num_input_tokens)
348-
batch_descriptor = BatchDescriptor(num_tokens=num_input_tokens,
349-
uniform_decode=False)
350-
aclgraph_runtime_mode, batch_descriptor = \
351-
self.runner.aclgraph_dispatcher.dispatch(batch_descriptor)
352347

353348
for step in range(self.num_speculative_tokens):
354349
with set_ascend_forward_context(
@@ -359,7 +354,6 @@ def _propose_torchair(
359354
num_tokens_across_dp=num_tokens_across_dp,
360355
reserved_mc2_mask=self.runner.reserved_mc2_mask,
361356
moe_comm_type=moe_comm_type,
362-
aclgraph_runtime_mode=aclgraph_runtime_mode,
363357
in_profile_run=self.runner.in_profile_run,
364358
num_actual_tokens=num_tokens):
365359
with ProfileExecuteDuration().capture_async('mtp_forward'):

0 commit comments

Comments (0)