diff --git a/vllm/v1/attention/backends/cpu_attn.py b/vllm/v1/attention/backends/cpu_attn.py
index f1254352c058..3fd181e7040a 100644
--- a/vllm/v1/attention/backends/cpu_attn.py
+++ b/vllm/v1/attention/backends/cpu_attn.py
@@ -159,11 +159,11 @@ def build(
         num_decode_tokens = 0
         if self.use_sdpa_prefill and causal:
             # Decoder, need reorder and truncate
-            assert self.reorder_batch_threshold
+            assert self._reorder_batch_threshold is not None
             (num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens) = (
                 split_decodes_and_prefills(
                     common_attn_metadata,
-                    decode_threshold=self.reorder_batch_threshold,
+                    decode_threshold=self._reorder_batch_threshold,
                     require_uniform=True,
                 )
             )
diff --git a/vllm/v1/attention/backends/flashinfer.py b/vllm/v1/attention/backends/flashinfer.py
index 4da1637d96eb..67794db9145e 100755
--- a/vllm/v1/attention/backends/flashinfer.py
+++ b/vllm/v1/attention/backends/flashinfer.py
@@ -404,8 +404,6 @@ class FlashInferMetadata:
 
 
 class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
-    reorder_batch_threshold: int = 1
-
     def __init__(
         self,
         kv_cache_spec: AttentionSpec,
@@ -651,10 +649,12 @@ def build(
     ) -> FlashInferMetadata:
         num_reqs = common_attn_metadata.num_reqs
         num_actual_tokens = common_attn_metadata.num_actual_tokens
+
+        assert self._reorder_batch_threshold is not None
         num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = (
             split_decodes_and_prefills(
                 common_attn_metadata,
-                decode_threshold=self.reorder_batch_threshold,
+                decode_threshold=self._reorder_batch_threshold,
                 require_uniform=True,
             )
         )
@@ -752,7 +752,8 @@ def build(
             paged_kv_last_page_len_np,
         )
 
-        uses_spec_reorder = self.reorder_batch_threshold > 1
+        uses_spec_reorder = self._reorder_batch_threshold > 1
+
         prefill_use_trtllm = use_trtllm_attention(
             self.num_qo_heads,
             self.num_kv_heads,
diff --git a/vllm/v1/attention/backends/gdn_attn.py b/vllm/v1/attention/backends/gdn_attn.py
index 69b5a6fb4856..29e984b6f595 100644
--- a/vllm/v1/attention/backends/gdn_attn.py
+++ b/vllm/v1/attention/backends/gdn_attn.py
@@ -61,8 +61,6 @@ class GDNAttentionMetadata:
 class GDNAttentionMetadataBuilder(AttentionMetadataBuilder[GDNAttentionMetadata]):
     _cudagraph_support = AttentionCGSupport.UNIFORM_BATCH
 
-    reorder_batch_threshold: int = 1
-
     def __init__(
         self,
         kv_cache_spec: AttentionSpec,
@@ -70,17 +68,20 @@ def __init__(
         vllm_config: VllmConfig,
         device: torch.device,
     ):
+        super().__init__(kv_cache_spec, layer_names, vllm_config, device)
+        self._init_reorder_batch_threshold(1, supports_spec_as_decode=True)
+
         assert isinstance(kv_cache_spec, MambaSpec)
         self.vllm_config = vllm_config
         self.compilation_config = vllm_config.compilation_config
         self.speculative_config = vllm_config.speculative_config
         self.kv_cache_spec = kv_cache_spec
+
        if self.speculative_config:
             self.num_spec = self.speculative_config.num_speculative_tokens
         else:
             self.num_spec = 0
         self.use_spec_decode = self.num_spec > 0
-        self._init_reorder_batch_threshold(1, self.use_spec_decode)
 
         self.use_full_cuda_graph = (
             self.compilation_config.cudagraph_mode.has_full_cudagraphs()
diff --git a/vllm/v1/attention/backends/linear_attn.py b/vllm/v1/attention/backends/linear_attn.py
index 1900c50849ec..0ae99e1dcc3a 100644
--- a/vllm/v1/attention/backends/linear_attn.py
+++ b/vllm/v1/attention/backends/linear_attn.py
@@ -33,8 +33,6 @@ class LinearAttentionMetadata:
 
 
 class LinearAttentionMetadataBuilder(AttentionMetadataBuilder[LinearAttentionMetadata]):
-    reorder_batch_threshold: int = 1
-
     def __init__(
         self,
         kv_cache_spec: AttentionSpec,
@@ -43,6 +41,7 @@ def __init__(
         device: torch.device,
     ):
         super().__init__(kv_cache_spec, layer_names, vllm_config, device)
+        self._init_reorder_batch_threshold(1)
         assert isinstance(kv_cache_spec, MambaSpec)
 
     def build(
@@ -56,9 +55,10 @@ def build(
 
         state_indices_tensor = common_attn_metadata.block_table_tensor[:, 0]
 
+        assert self._reorder_batch_threshold is not None
         num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = (
             split_decodes_and_prefills(
-                common_attn_metadata, decode_threshold=self.reorder_batch_threshold
+                common_attn_metadata, decode_threshold=self._reorder_batch_threshold
             )
         )
 
diff --git a/vllm/v1/attention/backends/mamba1_attn.py b/vllm/v1/attention/backends/mamba1_attn.py
index 8e949e53330c..c54b63f6a1a2 100644
--- a/vllm/v1/attention/backends/mamba1_attn.py
+++ b/vllm/v1/attention/backends/mamba1_attn.py
@@ -50,6 +50,7 @@ def __init__(
         device: torch.device,
     ):
         super().__init__(kv_cache_spec, layer_names, vllm_config, device)
+        self._init_reorder_batch_threshold(1)
         assert isinstance(kv_cache_spec, MambaSpec)
 
     def build(
@@ -60,9 +61,10 @@ def build(
     ) -> Mamba1AttentionMetadata:
         num_reqs = common_attn_metadata.num_reqs
 
+        assert self._reorder_batch_threshold is not None
         num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = (
             split_decodes_and_prefills(
-                common_attn_metadata, decode_threshold=self.reorder_batch_threshold
+                common_attn_metadata, decode_threshold=self._reorder_batch_threshold
             )
         )
 
diff --git a/vllm/v1/attention/backends/mamba2_attn.py b/vllm/v1/attention/backends/mamba2_attn.py
index 888734e5d2b6..0512dfa080bd 100644
--- a/vllm/v1/attention/backends/mamba2_attn.py
+++ b/vllm/v1/attention/backends/mamba2_attn.py
@@ -195,9 +195,10 @@ def build(
         block_idx_last_scheduled_token = None
         block_idx_last_computed_token = None
 
+        assert self._reorder_batch_threshold is not None
         num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = (
             split_decodes_and_prefills(
-                common_attn_metadata, decode_threshold=self.reorder_batch_threshold
+                common_attn_metadata, decode_threshold=self._reorder_batch_threshold
             )
         )
 
diff --git a/vllm/v1/attention/backends/mamba_attn.py b/vllm/v1/attention/backends/mamba_attn.py
index 0d875565fc99..4a54f83ce925 100644
--- a/vllm/v1/attention/backends/mamba_attn.py
+++ b/vllm/v1/attention/backends/mamba_attn.py
@@ -19,7 +19,6 @@
 
 
 class BaseMambaAttentionMetadataBuilder(AttentionMetadataBuilder[M], abc.ABC):
-    reorder_batch_threshold: int = 1
     _cudagraph_support: ClassVar[AttentionCGSupport] = (
         AttentionCGSupport.UNIFORM_SINGLE_TOKEN_DECODE
     )
@@ -32,6 +31,7 @@ def __init__(
         device: torch.device,
     ):
         super().__init__(kv_cache_spec, layer_names, vllm_config, device)
+        self._init_reorder_batch_threshold(1)
         assert isinstance(kv_cache_spec, MambaSpec)
 
         self.compilation_config = vllm_config.compilation_config
diff --git a/vllm/v1/attention/backends/mla/common.py b/vllm/v1/attention/backends/mla/common.py
index 2ccdd1f143ce..64fcee204c88 100755
--- a/vllm/v1/attention/backends/mla/common.py
+++ b/vllm/v1/attention/backends/mla/common.py
@@ -468,13 +468,6 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
     # speculative decoding is enabled.
     query_len_support: ClassVar[QueryLenSupport] = QueryLenSupport.SINGLE_ONLY
 
-    # The threshold for reordering the batch into decode and prefill requests.
-    # If > 1, the batch will be reordered such that requests with
-    # query length <= threshold are classified as decode requests.
-    # Use `query_len_support` (above) to set this automatically
-    # when speculative decoding is enabled.
-    reorder_batch_threshold: int = 1
-
     @staticmethod
     def determine_chunked_prefill_workspace_size(vllm_config: VllmConfig) -> int:
         scheduler_config = vllm_config.scheduler_config
@@ -513,7 +506,6 @@ def __init__(
         vllm_config: VllmConfig,
         device: torch.device,
         metadata_cls: type[M] | None = None,
-        supports_dcp_with_varlen: bool = False,
     ):
         self.metadata_cls = (
             metadata_cls if metadata_cls is not None else MLACommonMetadata
@@ -613,14 +605,14 @@ def __init__(
 
         supports_spec_decode = self.query_len_support != QueryLenSupport.SINGLE_ONLY
         self._init_reorder_batch_threshold(
-            self.reorder_batch_threshold, supports_spec_decode, supports_dcp_with_varlen
+            1, supports_spec_as_decode=supports_spec_decode
         )
 
         # Validate consistency between query_len_support and reorder_batch_threshold
         if self.query_len_support == QueryLenSupport.SINGLE_ONLY:
-            assert self.reorder_batch_threshold == 1, (
+            assert self._reorder_batch_threshold == 1, (
                 f"reorder_batch_threshold must be 1 when query_len_support is "
-                f"SINGLE_ONLY, got {self.reorder_batch_threshold}"
+                f"SINGLE_ONLY, got {self._reorder_batch_threshold}"
             )
 
     def _build_fi_prefill_wrappers(self, prefill: FlashInferPrefillMetadata):
@@ -723,12 +715,13 @@ def build_for_cudagraph_capture(
         Currently, only decode is supported for full cudagraphs with MLA.
         """
         m = common_attn_metadata
-        assert m.num_reqs <= (m.num_actual_tokens * self.reorder_batch_threshold), (
+        assert self._reorder_batch_threshold is not None
+        assert m.num_reqs <= (m.num_actual_tokens * self._reorder_batch_threshold), (
             "MLA only supports decode-only full CUDAGraph capture. "
             "Make sure all cudagraph capture sizes <= max_num_seq."
         )
 
-        assert m.max_query_len <= self.reorder_batch_threshold  # decode only
+        assert m.max_query_len <= self._reorder_batch_threshold  # decode only
 
         return self.build(0, m)
 
@@ -760,10 +753,11 @@ def build(
 
         num_computed_tokens_cpu = common_attn_metadata.seq_lens_cpu - query_seq_lens_cpu
 
+        assert self._reorder_batch_threshold is not None
         num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = (
             split_decodes_and_prefills(
                 common_attn_metadata,
-                decode_threshold=self.reorder_batch_threshold,
+                decode_threshold=self._reorder_batch_threshold,
                 require_uniform=(self.query_len_support != QueryLenSupport.VARLEN),
             )
         )
diff --git a/vllm/v1/attention/backends/mla/flashattn_mla.py b/vllm/v1/attention/backends/mla/flashattn_mla.py
index 7794e89cc0a9..3d422fa5cd27 100644
--- a/vllm/v1/attention/backends/mla/flashattn_mla.py
+++ b/vllm/v1/attention/backends/mla/flashattn_mla.py
@@ -94,7 +94,6 @@ class FlashAttnMLAMetadata(MLACommonMetadata[FlashAttnMLADecodeMetadata]):
 class FlashAttnMLAMetadataBuilder(MLACommonMetadataBuilder[FlashAttnMLAMetadata]):
     _cudagraph_support: ClassVar[AttentionCGSupport] = AttentionCGSupport.UNIFORM_BATCH
     query_len_support: ClassVar[QueryLenSupport] = QueryLenSupport.VARLEN
-    reorder_batch_threshold: int = 512  # process small prefills with decode pathway
 
     def __init__(
         self,
@@ -109,8 +108,9 @@ def __init__(
             vllm_config,
             device,
             FlashAttnMLAMetadata,
-            supports_dcp_with_varlen=True,
         )
+        self._init_reorder_batch_threshold(512, supports_dcp_with_varlen=True)
+
         self.max_num_splits = 0  # No upper bound on the number of splits.
         self.fa_aot_schedule = get_flash_attn_version() == 3
 
diff --git a/vllm/v1/attention/backends/mla/flashmla.py b/vllm/v1/attention/backends/mla/flashmla.py
index 3aab1f9bb7fb..9e991164594b 100644
--- a/vllm/v1/attention/backends/mla/flashmla.py
+++ b/vllm/v1/attention/backends/mla/flashmla.py
@@ -98,8 +98,6 @@ class FlashMLAMetadata(MLACommonMetadata[FlashMLADecodeMetadata]):
 class FlashMLAMetadataBuilder(MLACommonMetadataBuilder[FlashMLAMetadata]):
     _cudagraph_support: ClassVar[AttentionCGSupport] = AttentionCGSupport.UNIFORM_BATCH
     query_len_support: ClassVar[QueryLenSupport] = QueryLenSupport.UNIFORM
-    reorder_batch_threshold: int = 128  # process small prefills with decode pathway
-    # ^ TODO(matt): tune this
 
     def __init__(
         self,
@@ -111,6 +109,9 @@ def __init__(
         super().__init__(
             kv_cache_spec, layer_names, vllm_config, device, FlashMLAMetadata
         )
+        # process small prefills with decode pathway
+        self._init_reorder_batch_threshold(128)
+        # ^ TODO(matt): tune this
 
         self.num_q_heads = vllm_config.model_config.get_num_attention_heads(
             vllm_config.parallel_config
diff --git a/vllm/v1/attention/backends/rocm_aiter_fa.py b/vllm/v1/attention/backends/rocm_aiter_fa.py
index ad454daa582e..6cf0832ac54e 100644
--- a/vllm/v1/attention/backends/rocm_aiter_fa.py
+++ b/vllm/v1/attention/backends/rocm_aiter_fa.py
@@ -252,7 +252,6 @@ class AiterFlashAttentionMetadataBuilder(
     AttentionMetadataBuilder[AiterFlashAttentionMetadata]
 ):
     _cudagraph_support = AttentionCGSupport.UNIFORM_SINGLE_TOKEN_DECODE
-    reorder_batch_threshold: int = 1
 
     def __init__(
         self,
@@ -262,6 +261,7 @@ def __init__(
         device: torch.device,
     ):
         super().__init__(kv_cache_spec, layer_names, vllm_config, device)
+        self._init_reorder_batch_threshold(1)
 
         self.model_config = vllm_config.model_config
         self.parallel_config = vllm_config.parallel_config
@@ -301,9 +301,10 @@ def build(
         common_attn_metadata: CommonAttentionMetadata,
         fast_build: bool = False,
     ) -> "AiterFlashAttentionMetadata":
+        assert self._reorder_batch_threshold is not None
         split_ret = split_decodes_prefills_and_extends(
             common_attn_metadata,
-            decode_threshold=self.reorder_batch_threshold,
+            decode_threshold=self._reorder_batch_threshold,
         )
 
         (
diff --git a/vllm/v1/attention/backends/short_conv_attn.py b/vllm/v1/attention/backends/short_conv_attn.py
index de0cb73db091..1ecdf9a1bb88 100644
--- a/vllm/v1/attention/backends/short_conv_attn.py
+++ b/vllm/v1/attention/backends/short_conv_attn.py
@@ -53,9 +53,10 @@ def build(
         # for causal_conv1d
         nums_dict, batch_ptr, token_chunk_offset_ptr = None, None, None
 
+        assert self._reorder_batch_threshold is not None
         num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = (
             split_decodes_and_prefills(
-                common_attn_metadata, decode_threshold=self.reorder_batch_threshold
+                common_attn_metadata, decode_threshold=self._reorder_batch_threshold
             )
         )
 
diff --git a/vllm/v1/attention/backends/tree_attn.py b/vllm/v1/attention/backends/tree_attn.py
index 1bf38ed225a4..c5d90dbb8836 100644
--- a/vllm/v1/attention/backends/tree_attn.py
+++ b/vllm/v1/attention/backends/tree_attn.py
@@ -165,7 +165,7 @@ def __init__(
             device=device,
         )
 
-        self.reorder_batch_threshold = self.tree_attn_bias.shape[0]
+        self._init_reorder_batch_threshold(self.tree_attn_bias.shape[0])
 
     def build(
         self,
diff --git a/vllm/v1/attention/backends/utils.py b/vllm/v1/attention/backends/utils.py
index 578153cda786..859e4772f6f9 100644
--- a/vllm/v1/attention/backends/utils.py
+++ b/vllm/v1/attention/backends/utils.py
@@ -246,10 +246,6 @@ class AttentionMetadataBuilder(abc.ABC, Generic[M]):
     # Does this backend/builder support CUDA Graphs for attention (default: no).
     # Do not access directly. Call get_cudagraph_support() instead.
     _cudagraph_support: ClassVar[AttentionCGSupport] = AttentionCGSupport.NEVER
-    # Does this backend/builder reorder the batch?
-    # If not, set this to None. Otherwise set it to the query
-    # length that will be pulled into the front of the batch.
-    reorder_batch_threshold: int | None = None
 
     @abstractmethod
     def __init__(
@@ -264,6 +260,7 @@ def __init__(
         self.vllm_config = vllm_config
         self.device = device
 
+    # Does this backend/builder support CUDA Graphs for attention (default: no).
     @classmethod
     def get_cudagraph_support(
         cls: type["AttentionMetadataBuilder"],
@@ -273,14 +270,22 @@ def get_cudagraph_support(
         """Get the cudagraph support level of this builder class."""
         return cls._cudagraph_support
 
+    # Does this backend/builder reorder the batch?
+    # If not, set this to None. Otherwise set it to the query
+    # length that will be pulled into the front of the batch.
+    def get_reorder_batch_threshold(self) -> int | None:
+        if hasattr(self, "_reorder_batch_threshold"):
+            return self._reorder_batch_threshold
+        return None
+
     def _init_reorder_batch_threshold(
         self,
         reorder_batch_threshold: int | None = 1,
         supports_spec_as_decode: bool = False,
         supports_dcp_with_varlen: bool = False,
     ) -> None:
-        self.reorder_batch_threshold = reorder_batch_threshold
-        if self.reorder_batch_threshold is not None and supports_spec_as_decode:
+        self._reorder_batch_threshold = reorder_batch_threshold
+        if self._reorder_batch_threshold is not None and supports_spec_as_decode:
             # If the backend supports spec-as-decode kernels, then we can set
             # the reorder_batch_threshold based on the number of speculative
             # tokens from the config.
@@ -289,8 +294,8 @@ def _init_reorder_batch_threshold(
                 speculative_config is not None
                 and speculative_config.num_speculative_tokens is not None
             ):
-                self.reorder_batch_threshold = max(
-                    self.reorder_batch_threshold,
+                self._reorder_batch_threshold = max(
+                    self._reorder_batch_threshold,
                     1 + speculative_config.num_speculative_tokens,
                 )
 
@@ -298,7 +303,7 @@ def _init_reorder_batch_threshold(
             self.vllm_config.parallel_config.decode_context_parallel_size > 1
             and not supports_dcp_with_varlen
         ):
-            self.reorder_batch_threshold = 1
+            self._reorder_batch_threshold = 1
 
     @abstractmethod
     def build(
diff --git a/vllm/v1/attention/backends/xformers.py b/vllm/v1/attention/backends/xformers.py
index d15d79417cc6..e854bf984197 100644
--- a/vllm/v1/attention/backends/xformers.py
+++ b/vllm/v1/attention/backends/xformers.py
@@ -184,8 +184,6 @@ def decode_metadata(self) -> Optional["XFormersAttentionMetadata"]:
 class XFormersAttentionMetadataBuilder(
     AttentionMetadataBuilder[XFormersAttentionMetadata]
 ):
-    reorder_batch_threshold: int = 1
-
     def __init__(
         self,
         kv_cache_spec: AttentionSpec,
@@ -194,6 +192,7 @@ def __init__(
         device: torch.device,
     ):
         super().__init__(kv_cache_spec, layer_names, vllm_config, device)
+        self._init_reorder_batch_threshold(1)
         assert XFORMERS_AVAILABLE
 
         self.block_size = kv_cache_spec.block_size
@@ -206,9 +205,10 @@ def build(
         common_attn_metadata: CommonAttentionMetadata,
         fast_build: bool = False,
     ) -> XFormersAttentionMetadata:
+        assert self._reorder_batch_threshold is not None
         num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = (
             split_decodes_and_prefills(
-                common_attn_metadata, decode_threshold=self.reorder_batch_threshold
+                common_attn_metadata, decode_threshold=self._reorder_batch_threshold
             )
         )
 
diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py
index 9b3e5b668aab..4d39a20e4e76 100644
--- a/vllm/v1/worker/gpu_model_runner.py
+++ b/vllm/v1/worker/gpu_model_runner.py
@@ -628,16 +628,6 @@ def _may_reorder_batch(self, scheduler_output: "SchedulerOutput") -> None:
             return
 
         if self.reorder_batch_threshold is not None:
-            # NOTE(lucas): currently no backend supports the custom masking
-            # required for DCP with q_len > 1, so we assert here. Remove this
-            # assert once the custom mask is support is added to FA3.
-            if (
-                self.dcp_world_size > 1
-                and envs.VLLM_ATTENTION_BACKEND != "FLASH_ATTN_MLA"
-            ):
-                assert self.reorder_batch_threshold == 1, (
-                    "DCP not support reorder_batch_threshold > 1 now."
-                )
             reorder_batch_to_split_decodes_and_prefills(
                 self.input_batch,
                 scheduler_output,
@@ -4348,7 +4338,7 @@ def calculate_reorder_batch_threshold(self) -> None:
         min_none_high = lambda a, b: a if b is None else b if a is None else min(a, b)
 
         reorder_batch_thresholds = [
-            group.get_metadata_builder().reorder_batch_threshold
+            group.get_metadata_builder().get_reorder_batch_threshold()
             for group in self._attn_group_iterator()
         ]
         # If there are no attention groups (attention-free model) or no backend
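Not part of the patch: a minimal sketch of the usage pattern the diff converges on, assuming the vLLM classes shown above. ToyMetadataBuilder and its untyped build() signature are hypothetical stand-ins for the real backends; AttentionMetadataBuilder is from vllm/v1/attention/backends/utils.py as modified above, and split_decodes_and_prefills is assumed to live in the same utils module.

# Sketch only: illustrates the new _reorder_batch_threshold plumbing.
# ToyMetadataBuilder is hypothetical; the helper methods come from the
# AttentionMetadataBuilder changes in this diff.
from vllm.v1.attention.backends.utils import (
    AttentionMetadataBuilder,
    split_decodes_and_prefills,
)


class ToyMetadataBuilder(AttentionMetadataBuilder):
    def __init__(self, kv_cache_spec, layer_names, vllm_config, device):
        super().__init__(kv_cache_spec, layer_names, vllm_config, device)
        # Each builder now opts in explicitly in __init__ instead of declaring
        # a reorder_batch_threshold class attribute; spec-decode-capable
        # backends pass supports_spec_as_decode=True so the threshold is
        # raised to 1 + num_speculative_tokens when speculative decoding is on.
        self._init_reorder_batch_threshold(1)

    def build(self, common_prefix_len, common_attn_metadata, fast_build=False):
        # The instance attribute only exists after _init_reorder_batch_threshold
        # has run, hence the defensive assert used throughout the diff.
        assert self._reorder_batch_threshold is not None
        num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = (
            split_decodes_and_prefills(
                common_attn_metadata,
                decode_threshold=self._reorder_batch_threshold,
            )
        )
        ...


# External callers such as GPUModelRunner.calculate_reorder_batch_threshold()
# now go through the accessor, which returns None for builders that never
# called _init_reorder_batch_threshold (i.e. backends that do not reorder):
#     threshold = builder.get_reorder_batch_threshold()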