
Commit 412e153

[Feature] Allow configuring FlashInfer workspace size (#28269)
Signed-off-by: Max Hu <hyoung2991@gmail.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
1 parent e5f599d commit 412e153

File tree

3 files changed: +16 -12 lines

vllm/envs.py
vllm/v1/attention/backends/flashinfer.py
vllm/v1/attention/backends/mla/common.py

vllm/envs.py

Lines changed: 6 additions & 0 deletions

@@ -159,6 +159,7 @@
     VLLM_USE_FLASHINFER_MOE_FP8: bool = False
     VLLM_USE_FLASHINFER_MOE_FP4: bool = False
     VLLM_FLASHINFER_MOE_BACKEND: Literal["throughput", "latency"] = "latency"
+    VLLM_FLASHINFER_WORKSPACE_BUFFER_SIZE: int = 394 * 1024 * 1024
     VLLM_XGRAMMAR_CACHE_MB: int = 0
     VLLM_MSGPACK_ZERO_COPY_THRESHOLD: int = 256
     VLLM_ALLOW_INSECURE_SERIALIZATION: bool = False

@@ -1237,6 +1238,10 @@ def get_vllm_port() -> int | None:
     "VLLM_FLASHINFER_MOE_BACKEND": env_with_choices(
         "VLLM_FLASHINFER_MOE_BACKEND", "latency", ["throughput", "latency"]
     ),
+    # Control the workspace buffer size for the FlashInfer backend.
+    "VLLM_FLASHINFER_WORKSPACE_BUFFER_SIZE": lambda: int(
+        os.getenv("VLLM_FLASHINFER_WORKSPACE_BUFFER_SIZE", str(394 * 1024 * 1024))
+    ),
     # Control the maximum number of tokens per expert supported by the
     # NVFP4 MoE CUTLASS Kernel. This value is used to create a buffer for
     # the blockscale tensor of activations NVFP4 Quantization.

@@ -1583,6 +1588,7 @@ def compute_hash() -> str:
         "VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8",
         "VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8_CUTLASS",
         "VLLM_USE_FLASHINFER_MOE_MXFP4_BF16",
+        "VLLM_FLASHINFER_WORKSPACE_BUFFER_SIZE",
         "VLLM_USE_CUDNN_PREFILL",
         "VLLM_USE_TRTLLM_RAGGED_DEEPSEEK_PREFILL",
         "VLLM_USE_TRTLLM_ATTENTION",

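The three hooks above (default declaration, env lookup, and hash entry) are what make the buffer size tunable at runtime. A minimal usage sketch, not part of the commit: it assumes vLLM is installed and relies on the lambda registered in envs.py reading the variable lazily via os.getenv, so the override must be set before the value is first accessed.

import os

# Ask for a 512 MiB FlashInfer workspace instead of the 394 MiB default
# added above (512 MiB is an arbitrary illustrative value).
os.environ["VLLM_FLASHINFER_WORKSPACE_BUFFER_SIZE"] = str(512 * 1024 * 1024)

from vllm import envs

# The variable is parsed with int(os.getenv(...)), as in the hunk above.
assert envs.VLLM_FLASHINFER_WORKSPACE_BUFFER_SIZE == 512 * 1024 * 1024
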
vllm/v1/attention/backends/flashinfer.py

Lines changed: 3 additions & 3 deletions

@@ -16,6 +16,7 @@
 from flashinfer.prefill import trtllm_batch_context_with_kv_cache
 from flashinfer.utils import FP4Tensor

+from vllm import envs
 from vllm.attention.backends.abstract import (
     AttentionBackend,
     AttentionImpl,

@@ -55,7 +56,6 @@
 )
 from vllm.v1.kv_cache_interface import AttentionSpec

-FLASHINFER_WORKSPACE_BUFFER_SIZE = 256 * 1024 * 1024
 FLASHINFER_WORKSPACE_BUFFER_SIZE_BATCH_INVARIANT = 2048 * 1024 * 1024

 FP8_DTYPE = current_platform.fp8_dtype()

@@ -70,7 +70,7 @@ def _get_trtllm_gen_workspace_buffer():
     global trtllm_gen_workspace_buffer
     if trtllm_gen_workspace_buffer is None:
         trtllm_gen_workspace_buffer = torch.zeros(
-            FLASHINFER_WORKSPACE_BUFFER_SIZE, dtype=torch.uint8, device="cuda"
+            envs.VLLM_FLASHINFER_WORKSPACE_BUFFER_SIZE, dtype=torch.uint8, device="cuda"
         )
     return trtllm_gen_workspace_buffer

@@ -414,7 +414,7 @@ def __init__(

     def _get_workspace_buffer(self):
         if self._workspace_buffer is None:
-            buffer_size = FLASHINFER_WORKSPACE_BUFFER_SIZE
+            buffer_size = envs.VLLM_FLASHINFER_WORKSPACE_BUFFER_SIZE
             if vllm_is_batch_invariant():
                 buffer_size = FLASHINFER_WORKSPACE_BUFFER_SIZE_BATCH_INVARIANT
             self._workspace_buffer = torch.zeros(
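Both hunks in this file follow the same pattern: a process-wide uint8 scratch buffer, allocated lazily on first use, whose size now comes from the environment variable. A simplified, standalone sketch of that pattern (it omits the batch-invariant 2 GiB special case kept above and is only an illustration, not the backend's actual code):

import torch

from vllm import envs

_workspace_buffer: torch.Tensor | None = None


def get_workspace_buffer(device: str = "cuda") -> torch.Tensor:
    """Allocate the FlashInfer scratch buffer once, then reuse it."""
    global _workspace_buffer
    if _workspace_buffer is None:
        _workspace_buffer = torch.zeros(
            envs.VLLM_FLASHINFER_WORKSPACE_BUFFER_SIZE,  # size in bytes (uint8 elements)
            dtype=torch.uint8,
            device=device,
        )
    return _workspace_buffer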

vllm/v1/attention/backends/mla/common.py

Lines changed: 7 additions & 9 deletions

@@ -196,8 +196,8 @@
 import torch
 from tqdm import tqdm

-import vllm.envs as envs
 from vllm import _custom_ops as ops
+from vllm import envs
 from vllm._aiter_ops import rocm_aiter_ops
 from vllm.attention.backends.abstract import (
     AttentionBackend,

@@ -453,12 +453,6 @@ def use_trtllm_ragged_deepseek_prefill() -> bool:
     )


-# Currently 394MB, this can be tuned based on GEMM sizes used.
-# Chosen to be the same as sglang:
-# https://github.com/sgl-project/sglang/blob/766392c6bda2558b61ce6d1c1bfd8081a549e1f1/python/sglang/global_config.py#L37
-FLASHINFER_WORKSPACE_BUFFER_SIZE = 394 * 1024 * 1024
-
-
 class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
     """
     NOTE: Please read the comment at the top of the file before trying to

@@ -590,7 +584,9 @@ def __init__(

         if self._use_fi_prefill:
             self._workspace_buffer = torch.empty(
-                FLASHINFER_WORKSPACE_BUFFER_SIZE, dtype=torch.uint8, device=device
+                envs.VLLM_FLASHINFER_WORKSPACE_BUFFER_SIZE,
+                dtype=torch.uint8,
+                device=device,
             )

         self._fi_prefill_main: BatchPrefillWithRaggedKVCacheWrapper | None = None

@@ -602,7 +598,9 @@ def __init__(

         if self._use_trtllm_ragged_prefill:
             self._workspace_buffer = torch.empty(
-                FLASHINFER_WORKSPACE_BUFFER_SIZE, dtype=torch.uint8, device=device
+                envs.VLLM_FLASHINFER_WORKSPACE_BUFFER_SIZE,
+                dtype=torch.uint8,
+                device=device,
             )

         if self._use_cudnn_prefill:
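This file previously hard-coded its own 394 MiB constant (borrowed from sglang, per the deleted comment); the MLA prefill paths now read the same environment variable as the FlashInfer backend. For reference, a quick arithmetic check of the default and of a hypothetical override (the 768 MiB figure is illustrative only, not a recommendation from this commit):

# Default baked into envs.py above: 394 * 1024 * 1024 bytes.
default_bytes = 394 * 1024 * 1024
print(default_bytes)                     # 413138944 bytes, i.e. exactly 394 MiB

# A larger override, e.g. 768 MiB, would be expressed the same way:
override_bytes = 768 * 1024 * 1024
print(override_bytes // (1024 * 1024))   # 768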
