Commit 9af3475

[Bugfix] Fix model run _npu_flash_attention hang issue (#4410)
Fix a hang when the model run reaches `_npu_flash_attention` in `_forward_prefill_no_cache`; it was caused by the wrong attention mask dtype.

### How was this patch tested?

Yes, tested on Qwen2.5-VL and Qwen2.5-Omni.

- vLLM version: v0.11.0
- vLLM main: vllm-project/vllm@2918c1b

Signed-off-by: Ting FU <futing10@huawei.com>
1 parent: 048d350
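For background on the dtype issue, the sketch below is illustrative only and not the vllm-ascend implementation (`build_causal_mask` is a made-up name): it shows the kind of additive causal mask the prefill path expects, with zeros on and below the diagonal and `-inf` above it, in the requested floating-point dtype rather than `torch.bool`.

```python
# Minimal, hypothetical sketch: build an additive causal mask in the
# *requested* dtype. The reported hang was traced to the wrong mask dtype
# (a bool mask handed back where a float16 mask had been requested).
import torch


def build_causal_mask(max_seq_len: int, dtype: torch.dtype,
                      device: torch.device) -> torch.Tensor:
    # Future positions (above the diagonal) get -inf so softmax assigns
    # them zero weight; everything else stays 0 in the requested dtype.
    idx = torch.arange(max_seq_len, device=device)
    future = idx[None, :] > idx[:, None]
    mask = torch.zeros(max_seq_len, max_seq_len, dtype=dtype, device=device)
    return mask.masked_fill_(future, float("-inf"))


mask = build_causal_mask(2048, torch.float16, torch.device("cpu"))
assert mask.dtype == torch.float16      # dtype follows the request
assert mask[0][-1] == float("-inf")     # matches the updated unit test below
```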

File tree: 3 files changed (+6, −7 lines)

- tests/ut/attention/test_attention_mask.py
- vllm_ascend/attention/attention_mask.py
- vllm_ascend/worker/model_runner_v1.py

tests/ut/attention/test_attention_mask.py

Lines changed: 4 additions & 3 deletions
```diff
@@ -74,10 +74,11 @@ def test_get_attn_mask(self):
         attn_mask = attention_mask_builder.get_attn_mask(
             max_seq_len=2048, dtype=torch.float16, device=torch.device("cpu"))
         self.assertEqual(attn_mask.shape, (2048, 2048))
-        self.assertEqual(attn_mask[0][-1], torch.tensor(True))
-        self.assertEqual(attention_mask_builder._seq_len_cached, 1024)
+        self.assertEqual(attn_mask[0][-1],
+                         torch.tensor(float("-inf"), dtype=torch.float16))
+        self.assertEqual(attention_mask_builder._seq_len_cached, 2048)
         self.assertEqual(attention_mask_builder.attn_mask_cache.shape,
-                         (1024, 1024))
+                         (2048, 2048))
         self.assertEqual(attention_mask_builder.attn_mask_cache[0][-1],
                          torch.tensor(float("-inf"), dtype=torch.float16))

```

vllm_ascend/attention/attention_mask.py

Lines changed: 0 additions & 2 deletions
```diff
@@ -67,8 +67,6 @@ def get_mask_scale_factor(dtype: torch.dtype = torch.float16):

     def get_attn_mask(self, max_seq_len: int, dtype: torch.dtype,
                       device: torch.device):
-        if max_seq_len == 2048:
-            return self.chunked_prefill_attn_mask.to(torch.bool)
         self._update_attn_cache(max_seq_len, dtype)
         return self.attn_mask_cache[:max_seq_len, :max_seq_len].contiguous(
         ).to(device, non_blocking=True)
```
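With the `max_seq_len == 2048` special case gone, every request goes through `_update_attn_cache`, so the cached mask is (re)built in the requested dtype before being sliced. Below is a standalone sketch of that cache-and-slice pattern; the class name and cache-update rule are assumptions for illustration, not the real builder.

```python
# Standalone sketch; `_MaskCache` and its cache-update rule are illustrative
# assumptions, not the real vllm-ascend AttentionMaskBuilder.
import torch


class _MaskCache:

    def __init__(self):
        self._seq_len_cached = 0
        self.attn_mask_cache = torch.zeros(0, 0)

    def _update_attn_cache(self, max_seq_len: int, dtype: torch.dtype):
        # Assumed rule: regenerate when a longer sequence or a new dtype is
        # requested, so the returned slice always matches `dtype`.
        if (max_seq_len > self._seq_len_cached
                or dtype != self.attn_mask_cache.dtype):
            self._seq_len_cached = max(max_seq_len, self._seq_len_cached)
            n = self._seq_len_cached
            idx = torch.arange(n)
            future = idx[None, :] > idx[:, None]
            self.attn_mask_cache = torch.zeros(
                n, n, dtype=dtype).masked_fill_(future, float("-inf"))

    def get_attn_mask(self, max_seq_len: int, dtype: torch.dtype,
                      device: torch.device):
        # No 2048 special case: the request always flows through the cache,
        # so the dtype of the returned mask honors `dtype`.
        self._update_attn_cache(max_seq_len, dtype)
        return self.attn_mask_cache[:max_seq_len, :max_seq_len].contiguous(
        ).to(device, non_blocking=True)
```

On this sketch, `get_attn_mask(2048, torch.float16, torch.device("cpu"))` returns a float16 mask with `-inf` at `[0][-1]`, matching the assertions in the updated unit test above.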

vllm_ascend/worker/model_runner_v1.py

Lines changed: 2 additions & 2 deletions
```diff
@@ -991,8 +991,8 @@ def _make_attention_mask(self, seq_lens, position,
                 max_seq_len, self.dtype, self.device)
         # Prefill with cache hit.
         elif attn_state == AscendAttentionState.PrefillCacheHit:
-            return self.attn_mask_builder.get_attn_mask(
-                2048, self.dtype, self.device)
+            return self.attn_mask_builder.get_splitfuse_attn_mask().to(
+                torch.bool)
         # Decode-only situation.
         else:
             return None
```
