4 changes: 2 additions & 2 deletions vllm/v1/attention/backends/cpu_attn.py
@@ -159,11 +159,11 @@ def build(
num_decode_tokens = 0
if self.use_sdpa_prefill and causal:
# Decoder, need reorder and truncate
assert self.reorder_batch_threshold
assert self._reorder_batch_threshold is not None
(num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens) = (
split_decodes_and_prefills(
common_attn_metadata,
decode_threshold=self.reorder_batch_threshold,
decode_threshold=self._reorder_batch_threshold,
require_uniform=True,
)
)
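For context on what the threshold controls here: a minimal sketch of the splitting semantics, assuming (as the call above suggests) the batch has already been reordered so decode-like requests sit at the front. `split_by_decode_threshold` is a hypothetical stand-in for vLLM's `split_decodes_and_prefills`, not its actual implementation.

# Illustrative sketch only (not the vLLM helper): how a decode threshold
# partitions an already-reordered batch into decode and prefill requests.
def split_by_decode_threshold(
    query_lens: list[int], decode_threshold: int
) -> tuple[int, int, int, int]:
    num_decodes = 0
    for query_len in query_lens:
        if query_len > decode_threshold:
            break  # first prefill-sized request ends the decode run
        num_decodes += 1
    num_prefills = len(query_lens) - num_decodes
    num_decode_tokens = sum(query_lens[:num_decodes])
    num_prefill_tokens = sum(query_lens[num_decodes:])
    return num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens

# Example: with a threshold of 1, only single-token requests take the decode path.
assert split_by_decode_threshold([1, 1, 7, 32], 1) == (2, 2, 2, 39)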
9 changes: 5 additions & 4 deletions vllm/v1/attention/backends/flashinfer.py
@@ -404,8 +404,6 @@ class FlashInferMetadata:


class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
reorder_batch_threshold: int = 1

def __init__(
self,
kv_cache_spec: AttentionSpec,
@@ -651,10 +649,12 @@ def build(
) -> FlashInferMetadata:
num_reqs = common_attn_metadata.num_reqs
num_actual_tokens = common_attn_metadata.num_actual_tokens

assert self._reorder_batch_threshold is not None
num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = (
split_decodes_and_prefills(
common_attn_metadata,
decode_threshold=self.reorder_batch_threshold,
decode_threshold=self._reorder_batch_threshold,
require_uniform=True,
)
)
@@ -752,7 +752,8 @@ def build(
paged_kv_last_page_len_np,
)

uses_spec_reorder = self.reorder_batch_threshold > 1
uses_spec_reorder = self._reorder_batch_threshold > 1

prefill_use_trtllm = use_trtllm_attention(
self.num_qo_heads,
self.num_kv_heads,
7 changes: 4 additions & 3 deletions vllm/v1/attention/backends/gdn_attn.py
@@ -61,26 +61,27 @@ class GDNAttentionMetadata:
class GDNAttentionMetadataBuilder(AttentionMetadataBuilder[GDNAttentionMetadata]):
_cudagraph_support = AttentionCGSupport.UNIFORM_BATCH

reorder_batch_threshold: int = 1

def __init__(
self,
kv_cache_spec: AttentionSpec,
layer_names: list[str],
vllm_config: VllmConfig,
device: torch.device,
):
super().__init__(kv_cache_spec, layer_names, vllm_config, device)
self._init_reorder_batch_threshold(1, supports_spec_as_decode=True)

assert isinstance(kv_cache_spec, MambaSpec)
self.vllm_config = vllm_config
self.compilation_config = vllm_config.compilation_config
self.speculative_config = vllm_config.speculative_config
self.kv_cache_spec = kv_cache_spec

if self.speculative_config:
self.num_spec = self.speculative_config.num_speculative_tokens
else:
self.num_spec = 0
self.use_spec_decode = self.num_spec > 0
self._init_reorder_batch_threshold(1, self.use_spec_decode)

self.use_full_cuda_graph = (
self.compilation_config.cudagraph_mode.has_full_cudagraphs()
6 changes: 3 additions & 3 deletions vllm/v1/attention/backends/linear_attn.py
@@ -33,8 +33,6 @@ class LinearAttentionMetadata:


class LinearAttentionMetadataBuilder(AttentionMetadataBuilder[LinearAttentionMetadata]):
reorder_batch_threshold: int = 1

def __init__(
self,
kv_cache_spec: AttentionSpec,
@@ -43,6 +41,7 @@ def __init__(
device: torch.device,
):
super().__init__(kv_cache_spec, layer_names, vllm_config, device)
self._init_reorder_batch_threshold(1)
assert isinstance(kv_cache_spec, MambaSpec)

def build(
@@ -56,9 +55,10 @@ def __init__(

state_indices_tensor = common_attn_metadata.block_table_tensor[:, 0]

assert self._reorder_batch_threshold is not None
num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = (
split_decodes_and_prefills(
common_attn_metadata, decode_threshold=self.reorder_batch_threshold
common_attn_metadata, decode_threshold=self._reorder_batch_threshold
)
)

4 changes: 3 additions & 1 deletion vllm/v1/attention/backends/mamba1_attn.py
@@ -50,6 +50,7 @@ def __init__(
device: torch.device,
):
super().__init__(kv_cache_spec, layer_names, vllm_config, device)
self._init_reorder_batch_threshold(1)
assert isinstance(kv_cache_spec, MambaSpec)

def build(
@@ -60,9 +61,10 @@
) -> Mamba1AttentionMetadata:
num_reqs = common_attn_metadata.num_reqs

assert self._reorder_batch_threshold is not None
num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = (
split_decodes_and_prefills(
common_attn_metadata, decode_threshold=self.reorder_batch_threshold
common_attn_metadata, decode_threshold=self._reorder_batch_threshold
)
)

3 changes: 2 additions & 1 deletion vllm/v1/attention/backends/mamba2_attn.py
@@ -195,9 +195,10 @@ def build(
block_idx_last_scheduled_token = None
block_idx_last_computed_token = None

assert self._reorder_batch_threshold is not None
num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = (
split_decodes_and_prefills(
common_attn_metadata, decode_threshold=self.reorder_batch_threshold
common_attn_metadata, decode_threshold=self._reorder_batch_threshold
)
)

2 changes: 1 addition & 1 deletion vllm/v1/attention/backends/mamba_attn.py
@@ -19,7 +19,6 @@


class BaseMambaAttentionMetadataBuilder(AttentionMetadataBuilder[M], abc.ABC):
reorder_batch_threshold: int = 1
_cudagraph_support: ClassVar[AttentionCGSupport] = (
AttentionCGSupport.UNIFORM_SINGLE_TOKEN_DECODE
)
@@ -32,6 +31,7 @@ def __init__(
device: torch.device,
):
super().__init__(kv_cache_spec, layer_names, vllm_config, device)
self._init_reorder_batch_threshold(1)

assert isinstance(kv_cache_spec, MambaSpec)
self.compilation_config = vllm_config.compilation_config
22 changes: 8 additions & 14 deletions vllm/v1/attention/backends/mla/common.py
@@ -468,13 +468,6 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
# speculative decoding is enabled.
query_len_support: ClassVar[QueryLenSupport] = QueryLenSupport.SINGLE_ONLY

# The threshold for reordering the batch into decode and prefill requests.
# If > 1, the batch will be reordered such that requests with
# query length <= threshold are classified as decode requests.
# Use `query_len_support` (above) to set this automatically
# when speculative decoding is enabled.
reorder_batch_threshold: int = 1

@staticmethod
def determine_chunked_prefill_workspace_size(vllm_config: VllmConfig) -> int:
scheduler_config = vllm_config.scheduler_config
@@ -513,7 +506,6 @@ def __init__(
vllm_config: VllmConfig,
device: torch.device,
metadata_cls: type[M] | None = None,
supports_dcp_with_varlen: bool = False,
):
self.metadata_cls = (
metadata_cls if metadata_cls is not None else MLACommonMetadata
@@ -613,14 +605,14 @@ def __init__(

supports_spec_decode = self.query_len_support != QueryLenSupport.SINGLE_ONLY
self._init_reorder_batch_threshold(
self.reorder_batch_threshold, supports_spec_decode, supports_dcp_with_varlen
1, supports_spec_as_decode=supports_spec_decode
)

# Validate consistency between query_len_support and reorder_batch_threshold
if self.query_len_support == QueryLenSupport.SINGLE_ONLY:
assert self.reorder_batch_threshold == 1, (
assert self._reorder_batch_threshold == 1, (
f"reorder_batch_threshold must be 1 when query_len_support is "
f"SINGLE_ONLY, got {self.reorder_batch_threshold}"
f"SINGLE_ONLY, got {self._reorder_batch_threshold}"
)
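As a brief aside on the consistency check above, a hedged sketch of how the threshold is expected to resolve, mirroring the `_init_reorder_batch_threshold` logic added in `vllm/v1/attention/backends/utils.py` later in this diff; the decode-context-parallel fallback is omitted and `expected_threshold` is an illustrative name, not vLLM code.

# Sketch only: mirrors the spec-as-decode bump, omitting the DCP fallback.
def expected_threshold(
    base: int, supports_spec_as_decode: bool, num_speculative_tokens: int | None
) -> int:
    if supports_spec_as_decode and num_speculative_tokens is not None:
        return max(base, 1 + num_speculative_tokens)
    return base

# QueryLenSupport.SINGLE_ONLY -> supports_spec_decode is False, so the
# threshold stays at 1 and the assertion above always holds.
assert expected_threshold(1, False, 3) == 1
# UNIFORM/VARLEN backends with 3 speculative tokens treat up to 4 query
# tokens (1 verified + 3 speculative) as a decode.
assert expected_threshold(1, True, 3) == 4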

def _build_fi_prefill_wrappers(self, prefill: FlashInferPrefillMetadata):
@@ -723,12 +715,13 @@ def build_for_cudagraph_capture(
Currently, only decode is supported for full cudagraphs with MLA.
"""
m = common_attn_metadata
assert m.num_reqs <= (m.num_actual_tokens * self.reorder_batch_threshold), (
assert self._reorder_batch_threshold is not None
assert m.num_reqs <= (m.num_actual_tokens * self._reorder_batch_threshold), (
"MLA only supports decode-only full CUDAGraph capture. "
"Make sure all cudagraph capture sizes <= max_num_seq."
)

assert m.max_query_len <= self.reorder_batch_threshold # decode only
assert m.max_query_len <= self._reorder_batch_threshold # decode only

return self.build(0, m)

@@ -760,10 +753,11 @@ def build(

num_computed_tokens_cpu = common_attn_metadata.seq_lens_cpu - query_seq_lens_cpu

assert self._reorder_batch_threshold is not None
num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = (
split_decodes_and_prefills(
common_attn_metadata,
decode_threshold=self.reorder_batch_threshold,
decode_threshold=self._reorder_batch_threshold,
require_uniform=(self.query_len_support != QueryLenSupport.VARLEN),
)
)
4 changes: 2 additions & 2 deletions vllm/v1/attention/backends/mla/flashattn_mla.py
@@ -94,7 +94,6 @@ class FlashAttnMLAMetadata(MLACommonMetadata[FlashAttnMLADecodeMetadata]):
class FlashAttnMLAMetadataBuilder(MLACommonMetadataBuilder[FlashAttnMLAMetadata]):
_cudagraph_support: ClassVar[AttentionCGSupport] = AttentionCGSupport.UNIFORM_BATCH
query_len_support: ClassVar[QueryLenSupport] = QueryLenSupport.VARLEN
reorder_batch_threshold: int = 512 # process small prefills with decode pathway

def __init__(
self,
@@ -109,8 +108,9 @@ def __init__(
vllm_config,
device,
FlashAttnMLAMetadata,
supports_dcp_with_varlen=True,
)
self._init_reorder_batch_threshold(512, supports_dcp_with_varlen=True)

self.max_num_splits = 0 # No upper bound on the number of splits.
self.fa_aot_schedule = get_flash_attn_version() == 3

5 changes: 3 additions & 2 deletions vllm/v1/attention/backends/mla/flashmla.py
@@ -98,8 +98,6 @@ class FlashMLAMetadata(MLACommonMetadata[FlashMLADecodeMetadata]):
class FlashMLAMetadataBuilder(MLACommonMetadataBuilder[FlashMLAMetadata]):
_cudagraph_support: ClassVar[AttentionCGSupport] = AttentionCGSupport.UNIFORM_BATCH
query_len_support: ClassVar[QueryLenSupport] = QueryLenSupport.UNIFORM
reorder_batch_threshold: int = 128 # process small prefills with decode pathway
# ^ TODO(matt): tune this

def __init__(
self,
@@ -111,6 +109,9 @@ def __init__(
super().__init__(
kv_cache_spec, layer_names, vllm_config, device, FlashMLAMetadata
)
# process small prefills with decode pathway
self._init_reorder_batch_threshold(128)
# ^ TODO(matt): tune this

self.num_q_heads = vllm_config.model_config.get_num_attention_heads(
vllm_config.parallel_config
5 changes: 3 additions & 2 deletions vllm/v1/attention/backends/rocm_aiter_fa.py
@@ -252,7 +252,6 @@ class AiterFlashAttentionMetadataBuilder(
AttentionMetadataBuilder[AiterFlashAttentionMetadata]
):
_cudagraph_support = AttentionCGSupport.UNIFORM_SINGLE_TOKEN_DECODE
reorder_batch_threshold: int = 1

def __init__(
self,
@@ -262,6 +261,7 @@ def __init__(
device: torch.device,
):
super().__init__(kv_cache_spec, layer_names, vllm_config, device)
self._init_reorder_batch_threshold(1)

self.model_config = vllm_config.model_config
self.parallel_config = vllm_config.parallel_config
@@ -301,9 +301,10 @@ def build(
common_attn_metadata: CommonAttentionMetadata,
fast_build: bool = False,
) -> "AiterFlashAttentionMetadata":
assert self._reorder_batch_threshold is not None
split_ret = split_decodes_prefills_and_extends(
common_attn_metadata,
decode_threshold=self.reorder_batch_threshold,
decode_threshold=self._reorder_batch_threshold,
)

(
3 changes: 2 additions & 1 deletion vllm/v1/attention/backends/short_conv_attn.py
@@ -53,9 +53,10 @@ def build(
# for causal_conv1d
nums_dict, batch_ptr, token_chunk_offset_ptr = None, None, None

assert self._reorder_batch_threshold is not None
num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = (
split_decodes_and_prefills(
common_attn_metadata, decode_threshold=self.reorder_batch_threshold
common_attn_metadata, decode_threshold=self._reorder_batch_threshold
)
)

2 changes: 1 addition & 1 deletion vllm/v1/attention/backends/tree_attn.py
@@ -165,7 +165,7 @@ def __init__(
device=device,
)

self.reorder_batch_threshold = self.tree_attn_bias.shape[0]
self._init_reorder_batch_threshold(self.tree_attn_bias.shape[0])

def build(
self,
23 changes: 14 additions & 9 deletions vllm/v1/attention/backends/utils.py
@@ -246,10 +246,6 @@ class AttentionMetadataBuilder(abc.ABC, Generic[M]):
# Does this backend/builder support CUDA Graphs for attention (default: no).
# Do not access directly. Call get_cudagraph_support() instead.
_cudagraph_support: ClassVar[AttentionCGSupport] = AttentionCGSupport.NEVER
# Does this backend/builder reorder the batch?
# If not, set this to None. Otherwise set it to the query
# length that will be pulled into the front of the batch.
reorder_batch_threshold: int | None = None

@abstractmethod
def __init__(
@@ -264,6 +260,7 @@ def __init__(
self.vllm_config = vllm_config
self.device = device

# Does this backend/builder support CUDA Graphs for attention (default: no).
@classmethod
def get_cudagraph_support(
cls: type["AttentionMetadataBuilder"],
@@ -273,14 +270,22 @@
"""Get the cudagraph support level of this builder class."""
return cls._cudagraph_support

# Does this backend/builder reorder the batch?
# If not, set this to None. Otherwise set it to the query
# length that will be pulled into the front of the batch.
def get_reorder_batch_threshold(self) -> int | None:
if hasattr(self, "_reorder_batch_threshold"):
return self._reorder_batch_threshold
return None

def _init_reorder_batch_threshold(
self,
reorder_batch_threshold: int | None = 1,
supports_spec_as_decode: bool = False,
supports_dcp_with_varlen: bool = False,
) -> None:
self.reorder_batch_threshold = reorder_batch_threshold
if self.reorder_batch_threshold is not None and supports_spec_as_decode:
self._reorder_batch_threshold = reorder_batch_threshold
if self._reorder_batch_threshold is not None and supports_spec_as_decode:
# If the backend supports spec-as-decode kernels, then we can set
# the reorder_batch_threshold based on the number of speculative
# tokens from the config.
@@ -289,16 +294,16 @@ def _init_reorder_batch_threshold(
speculative_config is not None
and speculative_config.num_speculative_tokens is not None
):
self.reorder_batch_threshold = max(
self.reorder_batch_threshold,
self._reorder_batch_threshold = max(
self._reorder_batch_threshold,
1 + speculative_config.num_speculative_tokens,
)

if (
self.vllm_config.parallel_config.decode_context_parallel_size > 1
and not supports_dcp_with_varlen
):
self.reorder_batch_threshold = 1
self._reorder_batch_threshold = 1
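
To summarize the new pattern, a small self-contained model (not vLLM code) of how the instance-level threshold plus the getter are intended to be used together; the class name, constructor parameter, and asserts below are hypothetical, and the DCP fallback is again omitted.

# Minimal, self-contained model of the pattern introduced by this PR:
# the threshold lives on the instance and is read through a getter.
class BuilderSketch:
    def __init__(self, num_speculative_tokens: int | None = None) -> None:
        self._num_speculative_tokens = num_speculative_tokens
        self._init_reorder_batch_threshold(1, supports_spec_as_decode=True)

    def _init_reorder_batch_threshold(
        self,
        reorder_batch_threshold: int | None = 1,
        supports_spec_as_decode: bool = False,
    ) -> None:
        self._reorder_batch_threshold = reorder_batch_threshold
        if (
            self._reorder_batch_threshold is not None
            and supports_spec_as_decode
            and self._num_speculative_tokens is not None
        ):
            self._reorder_batch_threshold = max(
                self._reorder_batch_threshold, 1 + self._num_speculative_tokens
            )

    def get_reorder_batch_threshold(self) -> int | None:
        return getattr(self, "_reorder_batch_threshold", None)

# With 3 speculative tokens the decode path accepts up to 4 query tokens;
# callers read the value through the getter rather than a class attribute.
assert BuilderSketch(num_speculative_tokens=3).get_reorder_batch_threshold() == 4
assert BuilderSketch().get_reorder_batch_threshold() == 1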

@abstractmethod
def build(