6 changes: 6 additions & 0 deletions vllm/envs.py
@@ -159,6 +159,7 @@
VLLM_USE_FLASHINFER_MOE_FP8: bool = False
VLLM_USE_FLASHINFER_MOE_FP4: bool = False
VLLM_FLASHINFER_MOE_BACKEND: Literal["throughput", "latency"] = "latency"
VLLM_FLASHINFER_WORKSPACE_BUFFER_SIZE: int = 394 * 1024 * 1024

394 is a strange default.

I assume you meant 384, which is 128 + 256.

Suggested change
VLLM_FLASHINFER_WORKSPACE_BUFFER_SIZE: int = 394 * 1024 * 1024
VLLM_FLASHINFER_WORKSPACE_BUFFER_SIZE: int = 384 * 1024 * 1024

Contributor Author

This is directly copied from here, a known use case.

Contributor Author

@mratsim Hi, yeah, I saw the link, but I'd rather keep it unchanged as I don't know the exact amount of memory it requires.

Collaborator

@mratsim The point of this PR is that you can set the value of VLLM_FLASHINFER_WORKSPACE_BUFFER_SIZE for your specific use case, which is often not the same as the default.
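
For illustration, a minimal sketch of such an override from Python, assuming the variable is read from the process environment when vLLM starts; the 384 MiB figure is only an example, not a recommendation:

import os

# Hypothetical override: request a 384 MiB FlashInfer workspace buffer.
# vllm.envs interprets the value as a byte count.
os.environ["VLLM_FLASHINFER_WORKSPACE_BUFFER_SIZE"] = str(384 * 1024 * 1024)

The same effect can be achieved by exporting the variable in the shell before launching the server.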


My point is that if we are copying the SGLang default, we should copy it properly, i.e. 384, not an odd 394 that will waste space in the allocator because it is not close to a power of 2 or a sum of powers of 2 (see the quick check after the list):

  • 384 = 2⁸+2⁷
  • 394 = 2⁸+2⁷+2³+2
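
A quick check of these decompositions and the resulting byte counts (a throwaway snippet, not part of the diff):

# 384 MiB is the sum of two powers of two; 394 MiB needs four terms.
assert 384 == 2**8 + 2**7
assert 394 == 2**8 + 2**7 + 2**3 + 2**1
print(384 * 1024 * 1024)  # 402653184 bytes
print(394 * 1024 * 1024)  # 413138944 bytes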

VLLM_XGRAMMAR_CACHE_MB: int = 0
VLLM_MSGPACK_ZERO_COPY_THRESHOLD: int = 256
VLLM_ALLOW_INSECURE_SERIALIZATION: bool = False
@@ -1237,6 +1238,10 @@ def get_vllm_port() -> int | None:
"VLLM_FLASHINFER_MOE_BACKEND": env_with_choices(
"VLLM_FLASHINFER_MOE_BACKEND", "latency", ["throughput", "latency"]
),
# Control the workspace buffer size, in bytes, for the FlashInfer backend.
"VLLM_FLASHINFER_WORKSPACE_BUFFER_SIZE": lambda: int(
os.getenv("VLLM_FLASHINFER_WORKSPACE_BUFFER_SIZE", str(394 * 1024 * 1024))
),
# Control the maximum number of tokens per expert supported by the
# NVFP4 MoE CUTLASS Kernel. This value is used to create a buffer for
# the blockscale tensor of activations NVFP4 Quantization.
@@ -1583,6 +1588,7 @@ def compute_hash() -> str:
"VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8",
"VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8_CUTLASS",
"VLLM_USE_FLASHINFER_MOE_MXFP4_BF16",
"VLLM_FLASHINFER_WORKSPACE_BUFFER_SIZE",
"VLLM_USE_CUDNN_PREFILL",
"VLLM_USE_TRTLLM_RAGGED_DEEPSEEK_PREFILL",
"VLLM_USE_TRTLLM_ATTENTION",

6 changes: 3 additions & 3 deletions vllm/v1/attention/backends/flashinfer.py
@@ -16,6 +16,7 @@
from flashinfer.prefill import trtllm_batch_context_with_kv_cache
from flashinfer.utils import FP4Tensor

from vllm import envs
from vllm.attention.backends.abstract import (
AttentionBackend,
AttentionImpl,
@@ -55,7 +56,6 @@
)
from vllm.v1.kv_cache_interface import AttentionSpec

FLASHINFER_WORKSPACE_BUFFER_SIZE = 256 * 1024 * 1024
FLASHINFER_WORKSPACE_BUFFER_SIZE_BATCH_INVARIANT = 2048 * 1024 * 1024

FP8_DTYPE = current_platform.fp8_dtype()
@@ -70,7 +70,7 @@ def _get_trtllm_gen_workspace_buffer():
global trtllm_gen_workspace_buffer
if trtllm_gen_workspace_buffer is None:
trtllm_gen_workspace_buffer = torch.zeros(
FLASHINFER_WORKSPACE_BUFFER_SIZE, dtype=torch.uint8, device="cuda"
envs.VLLM_FLASHINFER_WORKSPACE_BUFFER_SIZE, dtype=torch.uint8, device="cuda"
)
return trtllm_gen_workspace_buffer

@@ -414,7 +414,7 @@ def __init__(

def _get_workspace_buffer(self):
if self._workspace_buffer is None:
buffer_size = FLASHINFER_WORKSPACE_BUFFER_SIZE
buffer_size = envs.VLLM_FLASHINFER_WORKSPACE_BUFFER_SIZE
if vllm_is_batch_invariant():
buffer_size = FLASHINFER_WORKSPACE_BUFFER_SIZE_BATCH_INVARIANT
self._workspace_buffer = torch.zeros(
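
Taken together, the hunks above keep the existing lazy-allocation pattern and only swap the hard-coded constant for the environment-driven value. A condensed sketch of that pattern, simplified from the diff and assuming this change is applied (requires a CUDA device):

import torch

from vllm import envs

_workspace_buffer: torch.Tensor | None = None


def _get_workspace_buffer() -> torch.Tensor:
    # Allocate the FlashInfer workspace once, sized in bytes by the env var,
    # and reuse the same tensor on every subsequent call.
    global _workspace_buffer
    if _workspace_buffer is None:
        _workspace_buffer = torch.zeros(
            envs.VLLM_FLASHINFER_WORKSPACE_BUFFER_SIZE,
            dtype=torch.uint8,
            device="cuda",
        )
    return _workspace_buffer
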
16 changes: 7 additions & 9 deletions vllm/v1/attention/backends/mla/common.py
@@ -196,8 -196,8 @@
import torch
from tqdm import tqdm

import vllm.envs as envs
from vllm import _custom_ops as ops
from vllm import envs
from vllm._aiter_ops import rocm_aiter_ops
from vllm.attention.backends.abstract import (
AttentionBackend,
@@ -453,12 +453,6 @@ def use_trtllm_ragged_deepseek_prefill() -> bool:
)


# Currently 394MB, this can be tuned based on GEMM sizes used.
# Chosen to be the same as sglang:
# https://github.com/sgl-project/sglang/blob/766392c6bda2558b61ce6d1c1bfd8081a549e1f1/python/sglang/global_config.py#L37
FLASHINFER_WORKSPACE_BUFFER_SIZE = 394 * 1024 * 1024


class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
"""
NOTE: Please read the comment at the top of the file before trying to
@@ -590,7 +584,9 @@ def __init__(

if self._use_fi_prefill:
self._workspace_buffer = torch.empty(
FLASHINFER_WORKSPACE_BUFFER_SIZE, dtype=torch.uint8, device=device
envs.VLLM_FLASHINFER_WORKSPACE_BUFFER_SIZE,
dtype=torch.uint8,
device=device,
)

self._fi_prefill_main: BatchPrefillWithRaggedKVCacheWrapper | None = None
Expand All @@ -602,7 +598,9 @@ def __init__(

if self._use_trtllm_ragged_prefill:
self._workspace_buffer = torch.empty(
FLASHINFER_WORKSPACE_BUFFER_SIZE, dtype=torch.uint8, device=device
envs.VLLM_FLASHINFER_WORKSPACE_BUFFER_SIZE,
dtype=torch.uint8,
device=device,
)

if self._use_cudnn_prefill: