From def02d17e247f81d79f0ec09ba9228209dcfb623 Mon Sep 17 00:00:00 2001
From: 17764591921
Date: Sat, 29 Nov 2025 11:07:34 +0800
Subject: [PATCH] [Bugfix] Fix the Eagle3 inference failure issue

Signed-off-by: sunchendd
---
 vllm_ascend/attention/attention_mask.py   | 31 ++++++++++++++++++++++-
 vllm_ascend/spec_decode/eagle_proposer.py |  6 ++---
 vllm_ascend/worker/model_runner_v1.py     | 11 ++++---
 3 files changed, 38 insertions(+), 10 deletions(-)

diff --git a/vllm_ascend/attention/attention_mask.py b/vllm_ascend/attention/attention_mask.py
index 3514984d826..1d4a231fc11 100644
--- a/vllm_ascend/attention/attention_mask.py
+++ b/vllm_ascend/attention/attention_mask.py
@@ -88,7 +88,36 @@ def get_splitfuse_attn_mask(
         dtype: torch.dtype = None,
         device: torch.device = None,
     ) -> torch.Tensor:
-        return self.chunked_prefill_attn_mask
+        cann_version = getattr(torch.version, "cann", "")
+        target_device = device or self.device
+        use_chunked_mask = (seq_lens is None or position is None
+                            or dtype is None or cann_version.startswith("8.3"))
+
+        if use_chunked_mask:
+            if target_device is None:
+                raise ValueError(
+                    "splitfuse_attn_mask requires device when using chunked mask"
+                )
+
+            return self.chunked_prefill_attn_mask.to(target_device,
+                                                     non_blocking=True)
+
+        if dtype not in [torch.float16, torch.bfloat16]:
+            raise ValueError(
+                "splitfuse_attn_mask now only supports bf16 and fp16")
+        if target_device is None:
+            raise ValueError(
+                "splitfuse_attn_mask requires device for non-chunked mask")
+        max_seq_len = seq_lens.max().item() if seq_lens.numel() > 0 else 0
+        self._update_attn_cache(max_seq_len, dtype)
+        # FIXME: Currently the mask value of chunked-prefill situation and Prefill-Only situation
+        # is not the same. Fix this in the future when kernel is ready.
+        mask_scale_factor = AttentionMaskBuilder.get_mask_scale_factor(dtype)
+        attn_mask = torch.index_select(self.attn_mask_cache,
+                                       dim=0,
+                                       index=position)[:, :max_seq_len]
+        attn_mask *= mask_scale_factor
+        return attn_mask.contiguous().to(target_device, non_blocking=True)
 
     def _update_attn_cache(self, seqlen: int, dtype: torch.dtype):
         if seqlen > self._seq_len_cached:
diff --git a/vllm_ascend/spec_decode/eagle_proposer.py b/vllm_ascend/spec_decode/eagle_proposer.py
index 75f01ee9bdb..9987fd6ced3 100644
--- a/vllm_ascend/spec_decode/eagle_proposer.py
+++ b/vllm_ascend/spec_decode/eagle_proposer.py
@@ -79,7 +79,7 @@ def __init__(self,
                                          dtype=torch.int32)
         attn_mask_len = self.vllm_config.model_config.max_model_len
         self.attn_mask_builder = AttentionMaskBuilder(
-            attn_mask_len, self.vllm_config.model_config.dtype)
+            attn_mask_len, self.vllm_config.model_config.dtype, device=device)
 
     def load_model(self, model: nn.Module) -> None:
         target_attn_layer_names = set(
@@ -427,9 +427,7 @@ def _propose(
         query_lens = cu_num_tokens[1:] - cu_num_tokens[:-1]
         max_query_len = query_lens.max().item()
 
-        attn_mask = self.attn_mask_builder.get_splitfuse_attn_mask(
-            seq_lens, target_positions, self.vllm_config.model_config.dtype,
-            self.device)
+        attn_mask = self.runner.attn_mask
 
         common_attn_metadata = AscendCommonAttentionMetadata(
             query_start_loc=cu_num_tokens.to(device),
diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py
index ff55d1d1897..79f868da94f 100644
--- a/vllm_ascend/worker/model_runner_v1.py
+++ b/vllm_ascend/worker/model_runner_v1.py
@@ -991,8 +991,8 @@ def _make_attention_mask(self, seq_lens, position, max_seq_len,
                                                      self.dtype, self.device)
         # Prefill with cache hit.
         elif attn_state == AscendAttentionState.PrefillCacheHit:
-            return self.attn_mask_builder.get_attn_mask(
-                2048, self.dtype, self.device)
+            return self.attn_mask_builder.get_splitfuse_attn_mask().to(
+                torch.bool)
         # Decode-only situation.
         else:
             return None
@@ -1954,10 +1954,11 @@ def _build_attn_state(self, num_reqs, num_scheduled_tokens,
             attn_state = AscendAttentionState.SpecDecoding
         # Speculative decoding.
         elif np.all(num_valid_tokens == 1):
-            if self.speculative_config and self.speculative_config.method == 'deepseek_mtp':
-                attn_state = AscendAttentionState.SpecDecoding
-            else:
+            if self.drafter and self.drafter.name in (SpecDcodeType.EAGLE,
+                                                      SpecDcodeType.EAGLE3):
                 attn_state = AscendAttentionState.ChunkedPrefill
+            else:
+                attn_state = AscendAttentionState.SpecDecoding
         # splitfuse
         elif not ascend_config.ascend_scheduler_config.enabled or self.chunked_prefill_enabled:
             attn_state = AscendAttentionState.ChunkedPrefill
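
Reviewer note (not part of the patch): the mask-selection behaviour added to get_splitfuse_attn_mask() above can be summarised with the standalone sketch below. ToyMaskBuilder, its toy shapes, and the demo calls are illustrative stand-ins only; the CANN 8.3 version check, the non_blocking copies, and the dtype scale factor applied via AttentionMaskBuilder.get_mask_scale_factor() in the real builder are simplified away here.

import torch


class ToyMaskBuilder:
    """Simplified stand-in for vllm_ascend's AttentionMaskBuilder."""

    def __init__(self, max_seq_len: int, device: torch.device):
        self.device = device
        ones = torch.ones(max_seq_len, max_seq_len)
        # Static boolean causal mask, used by the fallback (chunked) path.
        self.chunked_prefill_attn_mask = torch.triu(ones, diagonal=1).to(torch.bool)
        # Float cache that the per-position path slices; the real builder also
        # rescales it per dtype via get_mask_scale_factor().
        self.attn_mask_cache = torch.triu(ones, diagonal=1).to(torch.float16)

    def get_splitfuse_attn_mask(self, seq_lens=None, position=None,
                                dtype=None, device=None):
        target_device = device or self.device
        # Fallback path: with metadata missing (or on CANN 8.3 in the patch),
        # return the full chunked-prefill mask.
        if seq_lens is None or position is None or dtype is None:
            return self.chunked_prefill_attn_mask.to(target_device)
        if dtype not in (torch.float16, torch.bfloat16):
            raise ValueError("only bf16 and fp16 are supported")
        # Per-position path: pick one mask row per query position and trim the
        # columns to the longest sequence in the batch.
        max_seq_len = int(seq_lens.max().item()) if seq_lens.numel() > 0 else 0
        mask = torch.index_select(self.attn_mask_cache, dim=0,
                                  index=position)[:, :max_seq_len]
        return mask.contiguous().to(target_device)


builder = ToyMaskBuilder(max_seq_len=8, device=torch.device("cpu"))
# Fallback path, as used by the PrefillCacheHit branch in model_runner_v1.py:
print(builder.get_splitfuse_attn_mask().to(torch.bool).shape)   # torch.Size([8, 8])
# Per-position path with explicit metadata:
positions = torch.tensor([0, 1, 2, 3])
seq_lens = torch.tensor([4, 4])
print(builder.get_splitfuse_attn_mask(seq_lens, positions, torch.float16).shape)  # torch.Size([4, 4])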