vllm_ascend/attention/attention_mask.py (30 additions & 1 deletion)
@@ -88,7 +88,36 @@ def get_splitfuse_attn_mask(
         dtype: torch.dtype = None,
         device: torch.device = None,
     ) -> torch.Tensor:
-        return self.chunked_prefill_attn_mask
+        cann_version = getattr(torch.version, "cann", "")
+        target_device = device or self.device
+        use_chunked_mask = (seq_lens is None or position is None
+                            or dtype is None or cann_version.startswith("8.3"))
+
+        if use_chunked_mask:
+            if target_device is None:
+                raise ValueError(
+                    "splitfuse_attn_mask requires device when using chunked mask"
+                )
+
+            return self.chunked_prefill_attn_mask.to(target_device,
+                                                     non_blocking=True)
+
+        if dtype not in [torch.float16, torch.bfloat16]:
+            raise ValueError(
+                "splitfuse_attn_mask now only supports bf16 and fp16")
+        if target_device is None:
+            raise ValueError(
+                "splitfuse_attn_mask requires device for non-chunked mask")
+        max_seq_len = seq_lens.max().item() if seq_lens.numel() > 0 else 0
+        self._update_attn_cache(max_seq_len, dtype)
+        # FIXME: Currently the mask value of chunked-prefill situation and Prefill-Only situation
+        # is not the same. Fix this in the future when kernel is ready.
+        mask_scale_factor = AttentionMaskBuilder.get_mask_scale_factor(dtype)
+        attn_mask = torch.index_select(self.attn_mask_cache,
+                                       dim=0,
+                                       index=position)[:, :max_seq_len]
+        attn_mask *= mask_scale_factor
+        return attn_mask.contiguous().to(target_device, non_blocking=True)

     def _update_attn_cache(self, seqlen: int, dtype: torch.dtype):
         if seqlen > self._seq_len_cached:
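For orientation, here is a minimal usage sketch of the two paths the rewritten get_splitfuse_attn_mask can take. The constructor call mirrors the eagle_proposer.py change below; the 2048 mask length, fp16 dtype, and the "npu" device string are illustrative assumptions, not values taken from this PR.

    import torch
    from vllm_ascend.attention.attention_mask import AttentionMaskBuilder

    # Assumed arguments: (max mask length, dtype, device), as used in eagle_proposer.py below.
    builder = AttentionMaskBuilder(2048, torch.float16, device=torch.device("npu"))

    # Chunked path: with seq_lens/position/dtype omitted (or on CANN 8.3.x), the
    # shared chunked-prefill mask is returned, moved to the target device.
    chunked_mask = builder.get_splitfuse_attn_mask()

    # Per-position path: rows are gathered from the cached mask by position,
    # sliced to the longest sequence in the batch, and scaled for fp16/bf16.
    seq_lens = torch.tensor([5, 9])
    position = torch.tensor([4, 8])
    mask = builder.get_splitfuse_attn_mask(seq_lens, position, torch.float16,
                                           torch.device("npu"))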
vllm_ascend/spec_decode/eagle_proposer.py (2 additions & 4 deletions)
@@ -79,7 +79,7 @@ def __init__(self,
             dtype=torch.int32)
         attn_mask_len = self.vllm_config.model_config.max_model_len
         self.attn_mask_builder = AttentionMaskBuilder(
-            attn_mask_len, self.vllm_config.model_config.dtype)
+            attn_mask_len, self.vllm_config.model_config.dtype, device=device)

     def load_model(self, model: nn.Module) -> None:
         target_attn_layer_names = set(
@@ -427,9 +427,7 @@ def _propose(

         query_lens = cu_num_tokens[1:] - cu_num_tokens[:-1]
         max_query_len = query_lens.max().item()
-        attn_mask = self.attn_mask_builder.get_splitfuse_attn_mask(
-            seq_lens, target_positions, self.vllm_config.model_config.dtype,
-            self.device)
+        attn_mask = self.runner.attn_mask

         common_attn_metadata = AscendCommonAttentionMetadata(
             query_start_loc=cu_num_tokens.to(device),
vllm_ascend/worker/model_runner_v1.py (6 additions & 5 deletions)
@@ -991,8 +991,8 @@ def _make_attention_mask(self, seq_lens, position,
                 max_seq_len, self.dtype, self.device)
         # Prefill with cache hit.
         elif attn_state == AscendAttentionState.PrefillCacheHit:
-            return self.attn_mask_builder.get_attn_mask(
-                2048, self.dtype, self.device)
+            return self.attn_mask_builder.get_splitfuse_attn_mask().to(
+                torch.bool)
         # Decode-only situation.
         else:
             return None
@@ -1954,10 +1954,11 @@ def _build_attn_state(self, num_reqs, num_scheduled_tokens,
             attn_state = AscendAttentionState.SpecDecoding
         # Speculative decoding.
         elif np.all(num_valid_tokens == 1):
-            if self.speculative_config and self.speculative_config.method == 'deepseek_mtp':
-                attn_state = AscendAttentionState.SpecDecoding
-            else:
+            if self.drafter and self.drafter.name in (SpecDcodeType.EAGLE,
+                                                      SpecDcodeType.EAGLE3):
                 attn_state = AscendAttentionState.ChunkedPrefill
+            else:
+                attn_state = AscendAttentionState.SpecDecoding
Comment on lines +1957 to +1961
critical

The logic to determine the attention state for Eagle3 speculative decoding appears to be incorrect. Currently, it sets attn_state to AscendAttentionState.ChunkedPrefill for Eagle and Eagle3, and AscendAttentionState.SpecDecoding for other drafters. However, ChunkedPrefill is typically used for prefill stages, not for speculative decoding, which happens after prefill. For speculative decoding, AscendAttentionState.SpecDecoding should be used to ensure the correct attention mechanism is applied. Using ChunkedPrefill here could lead to incorrect attention masks and potential failures during the decoding phase.

Suggested change
-            if self.drafter and self.drafter.name in (SpecDcodeType.EAGLE,
-                                                      SpecDcodeType.EAGLE3):
-                attn_state = AscendAttentionState.ChunkedPrefill
-            else:
-                attn_state = AscendAttentionState.SpecDecoding
+            if self.drafter and self.drafter.name in (SpecDcodeType.EAGLE,
+                                                      SpecDcodeType.EAGLE3):
+                attn_state = AscendAttentionState.SpecDecoding
+            else:
+                attn_state = AscendAttentionState.SpecDecoding

         # splitfuse
         elif not ascend_config.ascend_scheduler_config.enabled or self.chunked_prefill_enabled:
             attn_state = AscendAttentionState.ChunkedPrefill
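For readers following the review thread, below is a condensed, self-contained sketch of the branch ordering in _build_attn_state after this change. The enum and the drafter_is_eagle flag are simplified stand-ins for the runner's real attributes, and branches not visible in this hunk (e.g. prefill without cache) are omitted; this illustrates the control flow under discussion, not the actual implementation.

    import numpy as np
    from enum import Enum, auto

    class State(Enum):
        DECODE_ONLY = auto()
        SPEC_DECODING = auto()
        CHUNKED_PREFILL = auto()
        PREFILL_CACHE_HIT = auto()

    def build_attn_state(num_scheduled_tokens: np.ndarray,
                         num_valid_tokens: np.ndarray,
                         drafter_is_eagle: bool,
                         use_chunked_prefill: bool) -> State:
        # Decode-only: every request schedules exactly one token.
        if np.all(num_scheduled_tokens == 1):
            return State.DECODE_ONLY
        # Speculative decoding: one valid token per request. The PR routes
        # EAGLE/EAGLE3 drafters to CHUNKED_PREFILL here; the review comment
        # above argues SPEC_DECODING is the correct state.
        if np.all(num_valid_tokens == 1):
            return (State.CHUNKED_PREFILL if drafter_is_eagle
                    else State.SPEC_DECODING)
        # Splitfuse / chunked prefill, otherwise prefill with a cache hit.
        if use_chunked_prefill:
            return State.CHUNKED_PREFILL
        return State.PREFILL_CACHE_HIT

    # Example: a two-request batch where each request has one valid token.
    state = build_attn_state(np.array([3, 5]), np.array([1, 1]),
                             drafter_is_eagle=True, use_chunked_prefill=False)
    print(state)  # State.CHUNKED_PREFILL under the PR's logic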