Commit 5f42fc5

[backends][short_conv] CUDA graph piecewise edits (vllm-project#24215)

Signed-off-by: Paul Pak <paulpak58@gmail.com>

1 parent 8ee846c

File tree (2 files changed, +21 -21 lines changed):

  vllm/model_executor/layers/mamba/short_conv.py
  vllm/v1/attention/backends/short_conv_attn.py
vllm/model_executor/layers/mamba/short_conv.py (1 addition, 1 deletion)

@@ -115,7 +115,7 @@ def forward_cuda(
         self_kv_cache = self.kv_cache[forward_context.virtual_engine]
         conv_state = self_kv_cache[0].transpose(-1, -2)
         state_indices_tensor = attn_metadata.state_indices_tensor
-        has_initial_states_p = attn_metadata.has_initial_states
+        has_initial_states_p = attn_metadata.has_initial_states_p
 
         BCx, _ = self.in_proj(hidden_states)
 
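Note on the rename: the _p suffix matches the metadata field introduced in the builder below and appears to follow the convention the other mamba-style backends use for prefill-only tensors; as the builder shows, the field is None whenever the batch contains no prefills.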
vllm/v1/attention/backends/short_conv_attn.py (20 additions, 20 deletions)
@@ -6,12 +6,12 @@
 import torch
 
 from vllm.attention.backends.abstract import AttentionBackend
-from vllm.config import VllmConfig
-from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder,
+from vllm.v1.attention.backends.mamba_attn import (
+    BaseMambaAttentionMetadataBuilder)
+from vllm.v1.attention.backends.utils import (PAD_SLOT_ID,
                                               CommonAttentionMetadata,
                                               compute_causal_conv1d_metadata,
                                               split_decodes_and_prefills)
-from vllm.v1.kv_cache_interface import AttentionSpec, MambaSpec
 
 
 class ShortConvAttentionBackend(AttentionBackend):
@@ -29,8 +29,8 @@ class ShortConvAttentionMetadata:
     num_decode_tokens: int
 
     query_start_loc: torch.Tensor
-    has_initial_states: torch.Tensor
-    state_indices_tensor: torch.Tensor  # shape: [batch,]
+    state_indices_tensor: torch.Tensor
+    has_initial_states_p: Optional[torch.Tensor]
 
     # For causal_conv1d
     nums_dict: Optional[dict] = None
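For orientation, the dataclass after this hunk plausibly reads as follows. This is a sketch assembled from the fields visible in the diff; the Optional import and the defaults on the last two fields are assumptions, not shown in this commit.

    @dataclass
    class ShortConvAttentionMetadata:
        num_prefills: int
        num_prefill_tokens: int
        num_decodes: int
        num_decode_tokens: int

        query_start_loc: torch.Tensor
        state_indices_tensor: torch.Tensor
        has_initial_states_p: Optional[torch.Tensor]

        # For causal_conv1d (types assumed)
        nums_dict: Optional[dict] = None
        batch_ptr: Optional[torch.Tensor] = None
        token_chunk_offset_ptr: Optional[torch.Tensor] = None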
@@ -39,22 +39,14 @@ class ShortConvAttentionMetadata:
 
 
 class ShortConvAttentionMetadataBuilder(
-        AttentionMetadataBuilder[ShortConvAttentionMetadata]):
-
-    reorder_batch_threshold: int = 1
-
-    def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
-                 vllm_config: VllmConfig, device: torch.device):
-        super().__init__(kv_cache_spec, layer_names, vllm_config, device)
-        assert isinstance(kv_cache_spec, MambaSpec)
+        BaseMambaAttentionMetadataBuilder[ShortConvAttentionMetadata]):
 
     def build(self,
               common_prefix_len: int,
               common_attn_metadata: CommonAttentionMetadata,
               fast_build: bool = False) -> ShortConvAttentionMetadata:
         num_reqs = common_attn_metadata.num_reqs
         query_start_loc = common_attn_metadata.query_start_loc
-
         state_indices_tensor = common_attn_metadata.block_table_tensor[:, 0]
 
         # for causal_conv1d
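The deleted __init__, assert, and reorder_batch_threshold now come from the base class. Below is a minimal sketch of what BaseMambaAttentionMetadataBuilder is assumed to provide for the CUDA-graph decode path used further down; the attribute names are inferred from what this commit references (self.decode_cudagraph_max_bs, self.compilation_config, self.state_indices_tensor), not copied from vllm.

    class BaseMambaAttentionMetadataBuilder(AttentionMetadataBuilder[M]):
        reorder_batch_threshold: int = 1

        def __init__(self, kv_cache_spec, layer_names, vllm_config, device):
            super().__init__(kv_cache_spec, layer_names, vllm_config, device)
            assert isinstance(kv_cache_spec, MambaSpec)
            self.compilation_config = vllm_config.compilation_config
            # Largest decode batch a captured CUDA graph can serve (assumed).
            self.decode_cudagraph_max_bs = min(
                vllm_config.scheduler_config.max_num_seqs,
                self.compilation_config.max_capture_size)
            # Persistent buffer: graph replay requires stable tensor storage,
            # so per-step indices are copied into this preallocated tensor.
            self.state_indices_tensor = torch.empty(
                (self.decode_cudagraph_max_bs, ),
                dtype=torch.int32,
                device=device)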
@@ -64,13 +56,13 @@ def build(self,
                 split_decodes_and_prefills(
                     common_attn_metadata,
                     decode_threshold=self.reorder_batch_threshold))
-        has_initial_states = None
+
+        has_initial_states_p = None
         if num_prefills > 0:
-            #[batch,]
             has_initial_states_cpu = (
                 common_attn_metadata.
                 num_computed_tokens_cpu[num_reqs - num_prefills:num_reqs] > 0)
-            has_initial_states = has_initial_states_cpu.to(
+            has_initial_states_p = has_initial_states_cpu.to(
                 query_start_loc.device)
 
         query_start_loc_p = common_attn_metadata.query_start_loc[
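A toy illustration of the prefill mask built here: split_decodes_and_prefills relies on the reordered batch layout (decodes first, then prefills), so the tail slice of num_computed_tokens_cpu covers exactly the prefill requests, and "> 0" flags those resuming from cached conv state. The concrete numbers are invented.

    import torch

    num_reqs, num_prefills = 5, 2
    # Decodes first, then prefills (reordered-batch layout).
    num_computed_tokens_cpu = torch.tensor([7, 12, 3, 0, 9])
    has_initial_states_cpu = (
        num_computed_tokens_cpu[num_reqs - num_prefills:num_reqs] > 0)
    print(has_initial_states_cpu)  # tensor([False,  True])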
@@ -79,14 +71,22 @@ def build(self,
             nums_dict, batch_ptr, token_chunk_offset_ptr = \
                 compute_causal_conv1d_metadata(query_start_loc_p)
 
+        elif (num_decodes > 0 and num_decodes <= self.decode_cudagraph_max_bs
+              and self.compilation_config.full_cuda_graph):
+            num_input_tokens = self.vllm_config.pad_for_cudagraph(num_decodes)
+            self.state_indices_tensor[:num_decodes].copy_(state_indices_tensor,
+                                                          non_blocking=True)
+            state_indices_tensor = self.state_indices_tensor[:num_input_tokens]
+            state_indices_tensor[num_decodes:] = PAD_SLOT_ID
+
         attn_metadata = ShortConvAttentionMetadata(
+            query_start_loc=query_start_loc,
+            state_indices_tensor=state_indices_tensor,
+            has_initial_states_p=has_initial_states_p,
             num_prefills=num_prefills,
             num_prefill_tokens=num_prefill_tokens,
             num_decodes=num_decodes,
             num_decode_tokens=num_decode_tokens,
-            query_start_loc=query_start_loc,
-            has_initial_states=has_initial_states,
-            state_indices_tensor=state_indices_tensor,
             nums_dict=nums_dict,
             batch_ptr=batch_ptr,
             token_chunk_offset_ptr=token_chunk_offset_ptr,
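The new elif branch is the piecewise CUDA-graph edit: for a decode-only batch small enough to be served by a captured graph, live state indices are copied into the persistent buffer, the batch is padded up to a captured size, and the padding slots are stamped with PAD_SLOT_ID so the kernel skips them. A standalone sketch of that padding arithmetic follows; pad_for_cudagraph here is a stand-in that rounds up to the nearest captured batch size, and the sizes and indices are invented.

    import torch

    PAD_SLOT_ID = -1               # sentinel for slots the kernel must skip
    capture_sizes = [1, 2, 4, 8]   # hypothetical captured graph batch sizes

    def pad_for_cudagraph(batch_size: int) -> int:
        # Round up to the smallest captured size that fits the batch.
        return next(s for s in capture_sizes if s >= batch_size)

    num_decodes = 3
    state_indices = torch.tensor([17, 4, 42], dtype=torch.int32)

    # Persistent buffer sized for the largest captured graph.
    buf = torch.full((capture_sizes[-1], ), PAD_SLOT_ID, dtype=torch.int32)
    num_input_tokens = pad_for_cudagraph(num_decodes)   # -> 4
    buf[:num_decodes].copy_(state_indices, non_blocking=True)
    state_indices = buf[:num_input_tokens]
    state_indices[num_decodes:] = PAD_SLOT_ID
    print(state_indices)   # tensor([17,  4, 42, -1], dtype=torch.int32)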
