@@ -6,12 +6,12 @@
 import torch
 
 from vllm.attention.backends.abstract import AttentionBackend
-from vllm.config import VllmConfig
-from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder,
+from vllm.v1.attention.backends.mamba_attn import (
+    BaseMambaAttentionMetadataBuilder)
+from vllm.v1.attention.backends.utils import (PAD_SLOT_ID,
                                              CommonAttentionMetadata,
                                              compute_causal_conv1d_metadata,
                                              split_decodes_and_prefills)
-from vllm.v1.kv_cache_interface import AttentionSpec, MambaSpec
 
 
 class ShortConvAttentionBackend(AttentionBackend):
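For context, a rough sketch of the builder surface this refactor leans on. The diff deletes the local `__init__` and `reorder_batch_threshold` because the base class is now expected to provide them; the names below mirror the attributes the new `build()` body reads (`reorder_batch_threshold`, `decode_cudagraph_max_bs`, `state_indices_tensor`). The real `BaseMambaAttentionMetadataBuilder` lives in `vllm/v1/attention/backends/mamba_attn.py`; this is only an illustrative assumption of its shape, not its actual definition:

```python
import torch


class MambaBuilderSurfaceSketch:
    """Illustrative stand-in for BaseMambaAttentionMetadataBuilder.

    Models only the attributes the ShortConv build() method touches;
    sizes, dtypes, and defaults here are assumptions.
    """

    # With threshold 1, pure decodes are grouped ahead of prefills.
    reorder_batch_threshold: int = 1

    def __init__(self, decode_cudagraph_max_bs: int, device: torch.device):
        self.decode_cudagraph_max_bs = decode_cudagraph_max_bs
        # Persistent slot-index buffer reused across CUDA-graph replays.
        self.state_indices_tensor = torch.empty(decode_cudagraph_max_bs,
                                                dtype=torch.int32,
                                                device=device)
```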
@@ -29,8 +29,8 @@ class ShortConvAttentionMetadata:
     num_decode_tokens: int
 
     query_start_loc: torch.Tensor
-    has_initial_states: torch.Tensor
-    state_indices_tensor: torch.Tensor  # shape: [batch,]
+    state_indices_tensor: torch.Tensor
+    has_initial_states_p: Optional[torch.Tensor]
 
     # For causal_conv1d
     nums_dict: Optional[dict] = None
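The dropped inline comment noted that `state_indices_tensor` is shaped `[batch,]`: one cache-slot index per request, taken from `block_table_tensor[:, 0]` in `build()` below. A toy gather, with made-up shapes, of how a consumer might use it (an assumption for illustration, not the actual kernel code):

```python
import torch

num_slots, batch = 8, 3
conv_state_cache = torch.randn(num_slots, 4, 16)  # [slots, dim, kernel-1], made up
state_indices_tensor = torch.tensor([5, 0, 2])    # shape: [batch,]

# Each request's conv state, gathered in batch order.
per_request_state = conv_state_cache[state_indices_tensor]  # [batch, 4, 16]
```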
@@ -39,22 +39,14 @@ class ShortConvAttentionMetadata:
 
 
 class ShortConvAttentionMetadataBuilder(
-        AttentionMetadataBuilder[ShortConvAttentionMetadata]):
-
-    reorder_batch_threshold: int = 1
-
-    def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
-                 vllm_config: VllmConfig, device: torch.device):
-        super().__init__(kv_cache_spec, layer_names, vllm_config, device)
-        assert isinstance(kv_cache_spec, MambaSpec)
+        BaseMambaAttentionMetadataBuilder[ShortConvAttentionMetadata]):
 
     def build(self,
               common_prefix_len: int,
               common_attn_metadata: CommonAttentionMetadata,
               fast_build: bool = False) -> ShortConvAttentionMetadata:
         num_reqs = common_attn_metadata.num_reqs
         query_start_loc = common_attn_metadata.query_start_loc
-
         state_indices_tensor = common_attn_metadata.block_table_tensor[:, 0]
 
         # for causal_conv1d
@@ -64,13 +56,13 @@ def build(self,
             split_decodes_and_prefills(
                 common_attn_metadata,
                 decode_threshold=self.reorder_batch_threshold))
-        has_initial_states = None
+
+        has_initial_states_p = None
         if num_prefills > 0:
-            #[batch,]
             has_initial_states_cpu = (
                 common_attn_metadata.
                 num_computed_tokens_cpu[num_reqs - num_prefills:num_reqs] > 0)
-            has_initial_states = has_initial_states_cpu.to(
+            has_initial_states_p = has_initial_states_cpu.to(
                 query_start_loc.device)
 
             query_start_loc_p = common_attn_metadata.query_start_loc[
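Because decodes are reordered ahead of prefills (the `reorder_batch_threshold` contract behind `split_decodes_and_prefills`), the prefill requests sit in the tail `[num_reqs - num_prefills, num_reqs)`, and a prefill has an initial conv state exactly when some of its tokens were already computed. A minimal sketch of that slicing with toy numbers (not from this PR):

```python
import torch

# 2 decodes followed by 3 prefills; the middle prefill resumes from a cache hit.
num_reqs, num_prefills = 5, 3
num_computed_tokens_cpu = torch.tensor([10, 7, 0, 64, 0])

# Mirrors the hunk above: inspect only the prefill tail of the batch.
has_initial_states_p = (
    num_computed_tokens_cpu[num_reqs - num_prefills:num_reqs] > 0)
print(has_initial_states_p)  # tensor([False,  True, False])
```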
@@ -79,14 +71,22 @@ def build(self,
             nums_dict, batch_ptr, token_chunk_offset_ptr = \
                 compute_causal_conv1d_metadata(query_start_loc_p)
 
+        elif (num_decodes > 0 and num_decodes <= self.decode_cudagraph_max_bs
+              and self.compilation_config.full_cuda_graph):
+            num_input_tokens = self.vllm_config.pad_for_cudagraph(num_decodes)
+            self.state_indices_tensor[:num_decodes].copy_(state_indices_tensor,
+                                                          non_blocking=True)
+            state_indices_tensor = self.state_indices_tensor[:num_input_tokens]
+            state_indices_tensor[num_decodes:] = PAD_SLOT_ID
+
         attn_metadata = ShortConvAttentionMetadata(
+            query_start_loc=query_start_loc,
+            state_indices_tensor=state_indices_tensor,
+            has_initial_states_p=has_initial_states_p,
             num_prefills=num_prefills,
             num_prefill_tokens=num_prefill_tokens,
             num_decodes=num_decodes,
             num_decode_tokens=num_decode_tokens,
-            query_start_loc=query_start_loc,
-            has_initial_states=has_initial_states,
-            state_indices_tensor=state_indices_tensor,
             nums_dict=nums_dict,
             batch_ptr=batch_ptr,
             token_chunk_offset_ptr=token_chunk_offset_ptr,
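The new `elif` branch pads decode-only batches up to a CUDA-graph-captured size: the real slot indices are copied into the persistent buffer and the stale tail is overwritten with `PAD_SLOT_ID` so padded lanes target a harmless dummy slot. A toy rehearsal of the same steps, assuming `PAD_SLOT_ID == -1` and standing in a fixed padded size for `pad_for_cudagraph`:

```python
import torch

PAD_SLOT_ID = -1  # assumed sentinel value for padded lanes

num_decodes, padded_bs = 3, 8                        # pad_for_cudagraph stand-in
persistent_buf = torch.empty(16, dtype=torch.int32)  # reused across replays

fresh_indices = torch.tensor([4, 9, 1], dtype=torch.int32)
persistent_buf[:num_decodes].copy_(fresh_indices)    # non_blocking=True on GPU
state_indices_tensor = persistent_buf[:padded_bs]
state_indices_tensor[num_decodes:] = PAD_SLOT_ID     # neutralize stale tail

print(state_indices_tensor)
# tensor([ 4,  9,  1, -1, -1, -1, -1, -1], dtype=torch.int32)
```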