diff --git a/vllm/envs.py b/vllm/envs.py
index 52a9671bc46e..5274c8ba1b24 100755
--- a/vllm/envs.py
+++ b/vllm/envs.py
@@ -159,6 +159,7 @@
     VLLM_USE_FLASHINFER_MOE_FP8: bool = False
     VLLM_USE_FLASHINFER_MOE_FP4: bool = False
     VLLM_FLASHINFER_MOE_BACKEND: Literal["throughput", "latency"] = "latency"
+    VLLM_FLASHINFER_WORKSPACE_BUFFER_SIZE: int = 394 * 1024 * 1024
     VLLM_XGRAMMAR_CACHE_MB: int = 0
     VLLM_MSGPACK_ZERO_COPY_THRESHOLD: int = 256
     VLLM_ALLOW_INSECURE_SERIALIZATION: bool = False
@@ -1237,6 +1238,10 @@ def get_vllm_port() -> int | None:
     "VLLM_FLASHINFER_MOE_BACKEND": env_with_choices(
         "VLLM_FLASHINFER_MOE_BACKEND", "latency", ["throughput", "latency"]
     ),
+    # Control the workspace buffer size for the FlashInfer backend.
+    "VLLM_FLASHINFER_WORKSPACE_BUFFER_SIZE": lambda: int(
+        os.getenv("VLLM_FLASHINFER_WORKSPACE_BUFFER_SIZE", str(394 * 1024 * 1024))
+    ),
     # Control the maximum number of tokens per expert supported by the
     # NVFP4 MoE CUTLASS Kernel. This value is used to create a buffer for
     # the blockscale tensor of activations NVFP4 Quantization.
@@ -1583,6 +1588,7 @@ def compute_hash() -> str:
         "VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8",
         "VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8_CUTLASS",
         "VLLM_USE_FLASHINFER_MOE_MXFP4_BF16",
+        "VLLM_FLASHINFER_WORKSPACE_BUFFER_SIZE",
         "VLLM_USE_CUDNN_PREFILL",
         "VLLM_USE_TRTLLM_RAGGED_DEEPSEEK_PREFILL",
         "VLLM_USE_TRTLLM_ATTENTION",
diff --git a/vllm/v1/attention/backends/flashinfer.py b/vllm/v1/attention/backends/flashinfer.py
index 07a0ab41a9e0..18bbc3cc3c12 100755
--- a/vllm/v1/attention/backends/flashinfer.py
+++ b/vllm/v1/attention/backends/flashinfer.py
@@ -16,6 +16,7 @@
 from flashinfer.prefill import trtllm_batch_context_with_kv_cache
 from flashinfer.utils import FP4Tensor
 
+from vllm import envs
 from vllm.attention.backends.abstract import (
     AttentionBackend,
     AttentionImpl,
@@ -55,7 +56,6 @@
 )
 from vllm.v1.kv_cache_interface import AttentionSpec
 
-FLASHINFER_WORKSPACE_BUFFER_SIZE = 256 * 1024 * 1024
 FLASHINFER_WORKSPACE_BUFFER_SIZE_BATCH_INVARIANT = 2048 * 1024 * 1024
 
 FP8_DTYPE = current_platform.fp8_dtype()
@@ -70,7 +70,7 @@ def _get_trtllm_gen_workspace_buffer():
     global trtllm_gen_workspace_buffer
     if trtllm_gen_workspace_buffer is None:
         trtllm_gen_workspace_buffer = torch.zeros(
-            FLASHINFER_WORKSPACE_BUFFER_SIZE, dtype=torch.uint8, device="cuda"
+            envs.VLLM_FLASHINFER_WORKSPACE_BUFFER_SIZE, dtype=torch.uint8, device="cuda"
         )
     return trtllm_gen_workspace_buffer
 
@@ -414,7 +414,7 @@ def __init__(
 
     def _get_workspace_buffer(self):
         if self._workspace_buffer is None:
-            buffer_size = FLASHINFER_WORKSPACE_BUFFER_SIZE
+            buffer_size = envs.VLLM_FLASHINFER_WORKSPACE_BUFFER_SIZE
             if vllm_is_batch_invariant():
                 buffer_size = FLASHINFER_WORKSPACE_BUFFER_SIZE_BATCH_INVARIANT
             self._workspace_buffer = torch.zeros(
diff --git a/vllm/v1/attention/backends/mla/common.py b/vllm/v1/attention/backends/mla/common.py
index 19bd102cb1e3..467c01cd9d06 100755
--- a/vllm/v1/attention/backends/mla/common.py
+++ b/vllm/v1/attention/backends/mla/common.py
@@ -196,8 +196,8 @@
 import torch
 from tqdm import tqdm
 
-import vllm.envs as envs
 from vllm import _custom_ops as ops
+from vllm import envs
 from vllm._aiter_ops import rocm_aiter_ops
 from vllm.attention.backends.abstract import (
     AttentionBackend,
@@ -453,12 +453,6 @@ def use_trtllm_ragged_deepseek_prefill() -> bool:
     )
 
 
-# Currently 394MB, this can be tuned based on GEMM sizes used.
-# Chosen to be the same as sglang:
-# https://github.com/sgl-project/sglang/blob/766392c6bda2558b61ce6d1c1bfd8081a549e1f1/python/sglang/global_config.py#L37
-FLASHINFER_WORKSPACE_BUFFER_SIZE = 394 * 1024 * 1024
-
-
 class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
     """
     NOTE: Please read the comment at the top of the file before trying to
@@ -590,7 +584,9 @@ def __init__(
 
         if self._use_fi_prefill:
             self._workspace_buffer = torch.empty(
-                FLASHINFER_WORKSPACE_BUFFER_SIZE, dtype=torch.uint8, device=device
+                envs.VLLM_FLASHINFER_WORKSPACE_BUFFER_SIZE,
+                dtype=torch.uint8,
+                device=device,
             )
 
             self._fi_prefill_main: BatchPrefillWithRaggedKVCacheWrapper | None = None
@@ -602,7 +598,9 @@
 
         if self._use_trtllm_ragged_prefill:
             self._workspace_buffer = torch.empty(
-                FLASHINFER_WORKSPACE_BUFFER_SIZE, dtype=torch.uint8, device=device
+                envs.VLLM_FLASHINFER_WORKSPACE_BUFFER_SIZE,
+                dtype=torch.uint8,
+                device=device,
             )
 
         if self._use_cudnn_prefill:
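For reviewers who want to sanity-check the override path, here is a minimal standalone sketch (not vLLM code) of the parsing behavior this diff registers in `vllm/envs.py`. The helper name is hypothetical, and the default mirrors the 394 MiB value previously hard-coded in `mla/common.py`:

```python
import os

# Default from the diff: 394 MiB (413,138,944 bytes), the constant previously
# hard-coded in vllm/v1/attention/backends/mla/common.py.
_DEFAULT_WORKSPACE_BYTES = 394 * 1024 * 1024


def flashinfer_workspace_buffer_size() -> int:
    # Hypothetical helper mirroring the lambda added to environment_variables:
    # read the raw env var (in bytes) and fall back to the 394 MiB default.
    return int(
        os.getenv(
            "VLLM_FLASHINFER_WORKSPACE_BUFFER_SIZE",
            str(_DEFAULT_WORKSPACE_BYTES),
        )
    )


if __name__ == "__main__":
    # e.g. VLLM_FLASHINFER_WORKSPACE_BUFFER_SIZE=$((512 * 1024 * 1024)) python sketch.py
    print(flashinfer_workspace_buffer_size())
```

The value is interpreted in bytes and sizes the `torch.uint8` workspace tensors allocated for the FlashInfer wrappers, so a larger override trades GPU memory for kernel workspace headroom.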