
Commit 1271416

rm vanilla attn
Signed-off-by: zzzzwwjj <1183291235@qq.com>
1 parent 2b82320 commit 1271416

3 files changed: +0 additions, -192 deletions


tests/ut/attention/test_attention_v1.py

Lines changed: 0 additions & 37 deletions
@@ -454,43 +454,6 @@ def test_forward_decode_only_swa_seq_len_mismatch(

         assert output.shape == (10, 8 * 64)

-    @patch('vllm_ascend.attention.attention_v1.get_forward_context')
-    @patch('vllm_ascend.utils.get_ascend_device_type',
-           return_value=AscendDeviceType._910_93)
-    @patch('torch_npu._npu_reshape_and_cache')
-    @patch('vllm_ascend.attention.attention_v1.vanilla_chunked_prefill')
-    def test_forward_head_size_192(self, mock_vanilla_prefill,
-                                   mock_npu_reshape_and_cache,
-                                   mock_soc_version, mock_get_forward_context):
-        """Test forward pass when head_size is 192"""
-
-        self.impl.head_size = 192
-        query = torch.randn(10, 8 * 192)
-        key = torch.randn(10, 8 * 192)
-        value = torch.randn(10, 8 * 192)
-        kv_cache = torch.empty(2, 5, 128, 8, 192)
-        output = torch.empty_like(query)
-
-        mock_get_forward_context.return_value = MagicMock(capturing=False)
-
-        metadata = self.attn_metadata
-        metadata.attn_mask = torch.randn(1, 1, 10, 10)
-        metadata.query_lens = torch.tensor([10])
-        metadata.seq_lens = torch.tensor([10])
-        metadata.block_tables = torch.zeros(1, 5, dtype=torch.long)
-        metadata.num_actual_tokens = 10
-        metadata.slot_mapping = torch.zeros(10, dtype=torch.long)
-        metadata.num_decodes = 10
-        metadata.num_prefills = 0
-        layer = self.layer_no_quant
-        mock_vanilla_prefill.return_value = MagicMock()
-
-        output = self.impl_192.forward(layer, query, key, value, kv_cache,
-                                       metadata, output)
-
-        mock_vanilla_prefill.assert_called_once()
-        assert output.shape == (10, 8 * 192)
-
     @patch('vllm_ascend.attention.attention_v1.get_forward_context')
     @patch('torch_npu.npu_fused_infer_attention_score')
     @patch('torch_npu._npu_reshape_and_cache')

vllm_ascend/attention/attention_v1.py

Lines changed: 0 additions & 21 deletions
@@ -41,7 +41,6 @@
                                              split_decodes_and_prefills)
 from vllm_ascend.compilation.acl_graph import (get_graph_params,
                                                update_graph_params_workspaces)
-from vllm_ascend.ops.attention import vanilla_chunked_prefill
 from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_NZ, AscendDeviceType,
                                aligned_16, get_ascend_device_type, nd_to_nz_2d,
                                nd_to_nz_spec, prefill_context_parallel_enable,
@@ -833,26 +832,6 @@ def _forward_v1_style(
         attn_metadata: AscendMetadata,
         output: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
-        # Use chunked prefill for head size 192 scenario, like deepseek
-        # paged_attention_splitfuse maybe crash at such scenario.
-        # TODO: vanilla path will be removed after the kernel support
-        # head_size 192 scenario.
-        if self.head_size == 192:
-            cu_seqlen_q = [0] + attn_metadata.query_lens.tolist()
-            cu_seqlen_k = [0] + attn_metadata.seq_lens.tolist()
-            cu_seqlen_q = torch.tensor(cu_seqlen_q, device=query.device)
-            cu_seqlen_k = torch.tensor(cu_seqlen_k, device=query.device)
-            cu_seqlen_q = torch.cumsum(cu_seqlen_q, dim=0)
-            cu_seqlen_k = torch.cumsum(cu_seqlen_k, dim=0)
-            max_seqlen_q = torch.max(attn_metadata.query_lens)
-            max_seqlen_k = torch.max(attn_metadata.seq_lens)
-            vanilla_chunked_prefill(output, query, self.key_cache,
-                                    self.value_cache,
-                                    attn_metadata.block_tables, cu_seqlen_q,
-                                    cu_seqlen_k, max_seqlen_q, max_seqlen_k,
-                                    self.scale, None, True)
-            return output
-
         # Use paged attention.
         assert attn_metadata is not None
         assert attn_metadata.attn_mask is not None
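
Note: the removed branch above packed per-request lengths into cumulative offsets before dispatching to vanilla_chunked_prefill. A minimal standalone sketch of that cumulative-length bookkeeping follows; the lengths and tensor values are illustrative examples, not taken from the commit.

import torch

# Example per-request lengths for three requests (illustrative values only).
query_lens = torch.tensor([3, 5, 2])   # new tokens per request
seq_lens = torch.tensor([7, 9, 4])     # total context (cached + new) per request

# Same pattern as the deleted branch: prepend 0, then take a running sum so
# that request i owns rows cu[i]:cu[i+1] of the packed query/key tensors.
cu_seqlen_q = torch.cumsum(torch.tensor([0] + query_lens.tolist()), dim=0)
cu_seqlen_k = torch.cumsum(torch.tensor([0] + seq_lens.tolist()), dim=0)

print(cu_seqlen_q.tolist())   # [0, 3, 8, 10]
print(cu_seqlen_k.tolist())   # [0, 7, 16, 20]

max_seqlen_q = int(query_lens.max())  # 5
max_seqlen_k = int(seq_lens.max())    # 9

These offsets, together with the block tables and the softmax scale, made up the arguments the fallback passed to vanilla_chunked_prefill.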

vllm_ascend/ops/attention.py

Lines changed: 0 additions & 134 deletions
This file was deleted.
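
This deleted module contained the vanilla_chunked_prefill helper itself; its 134 lines are not reproduced in this view. Purely as a conceptual illustration of what an unfused ("vanilla") prefill over a paged KV cache involves, and not a reconstruction of the deleted code, a single request's keys and values can be gathered through its block table and fed to plain scaled dot-product attention with a causal mask:

import math
import torch

def naive_paged_prefill(query, key_cache, value_cache, block_table, seq_len,
                        scale):
    # Assumed, illustrative layouts (not the deleted code's):
    #   query:       (q_len, num_heads, head_size) for one request
    #   key_cache:   (num_blocks, block_size, num_heads, head_size)
    #   value_cache: same layout as key_cache
    #   block_table: 1-D long tensor of block ids for this request
    block_size = key_cache.shape[1]
    n_blocks = math.ceil(seq_len / block_size)

    # Gather and flatten this request's blocks, then trim to its true length.
    k = key_cache[block_table[:n_blocks]].flatten(0, 1)[:seq_len]
    v = value_cache[block_table[:n_blocks]].flatten(0, 1)[:seq_len]

    q = query.transpose(0, 1)          # (num_heads, q_len, head_size)
    k = k.transpose(0, 1)              # (num_heads, seq_len, head_size)
    v = v.transpose(0, 1)
    scores = torch.matmul(q, k.transpose(-1, -2)) * scale

    # Causal mask: the q_len new tokens sit at the end of the seq_len context.
    q_len = q.shape[1]
    causal = torch.ones(q_len, seq_len, dtype=torch.bool).tril(seq_len - q_len)
    scores = scores.masked_fill(~causal, float('-inf'))
    return torch.matmul(torch.softmax(scores, dim=-1), v).transpose(0, 1)

A chunked implementation would additionally process the context in pieces rather than materializing the full score matrix at once; that bookkeeping is omitted here.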
