diff --git a/tests/kernels/attention/test_flashinfer_mla_decode.py b/tests/kernels/attention/test_flashinfer_mla_decode.py
index 02225432f77f..35cbdbfd8ad7 100644
--- a/tests/kernels/attention/test_flashinfer_mla_decode.py
+++ b/tests/kernels/attention/test_flashinfer_mla_decode.py
@@ -1,5 +1,6 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+import os
 import pytest
 import torch
 import torch.nn.functional as F
@@ -8,7 +9,8 @@
 
 from vllm.platforms import current_platform
 
-FLASHINFER_WORKSPACE_BUFFER_SIZE = 128 * 1024 * 1024
+FLASHINFER_WORKSPACE_BUFFER_SIZE = int(
+    os.environ.get("VLLM_FLASHINFER_WORKSPACE_BUFFER_SIZE", 128 * 1024 * 1024))
 
 if not current_platform.has_device_capability(100):
     pytest.skip(
diff --git a/vllm/v1/attention/backends/flashinfer.py b/vllm/v1/attention/backends/flashinfer.py
index cb092aa74e7f..c0facad8f203 100755
--- a/vllm/v1/attention/backends/flashinfer.py
+++ b/vllm/v1/attention/backends/flashinfer.py
@@ -3,6 +3,7 @@
 """Attention layer with FlashInfer."""
 from __future__ import annotations
 
+import os
 from dataclasses import dataclass
 from typing import ClassVar, Optional, Union
 
@@ -41,7 +42,12 @@
 # yapf: enable
 from vllm.v1.kv_cache_interface import AttentionSpec
 
-FLASHINFER_WORKSPACE_BUFFER_SIZE = 256 * 1024 * 1024
+_default_workspace_size = 256 * 1024 * 1024
+_env_val = os.environ.get("VLLM_FLASHINFER_WORKSPACE_BUFFER_SIZE", _default_workspace_size)
+try:
+    FLASHINFER_WORKSPACE_BUFFER_SIZE = int(_env_val)
+except (ValueError, TypeError):
+    FLASHINFER_WORKSPACE_BUFFER_SIZE = _default_workspace_size
 
 FP8_DTYPE = current_platform.fp8_dtype()
 FP4_DTYPE = torch.uint8
diff --git a/vllm/v1/attention/backends/mla/common.py b/vllm/v1/attention/backends/mla/common.py
index 5b307810de93..3650dd7fb8e7 100755
--- a/vllm/v1/attention/backends/mla/common.py
+++ b/vllm/v1/attention/backends/mla/common.py
@@ -188,6 +188,7 @@
 """
 
 import functools
+import os
 from abc import abstractmethod
 from dataclasses import dataclass, field
 from typing import ClassVar, Generic, Optional, TypeVar, Union
@@ -426,7 +427,8 @@ def use_cudnn_prefill() -> bool:
 # Currently 394MB, this can be tuned based on GEMM sizes used.
 # Chosen to be the same as sglang:
 # https://github.com/sgl-project/sglang/blob/766392c6bda2558b61ce6d1c1bfd8081a549e1f1/python/sglang/global_config.py#L37
-FLASHINFER_WORKSPACE_BUFFER_SIZE = 394 * 1024 * 1024
+FLASHINFER_WORKSPACE_BUFFER_SIZE = int(
+    os.environ.get("VLLM_FLASHINFER_WORKSPACE_BUFFER_SIZE", 394 * 1024 * 1024))
 
 
 class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
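
For reference, a minimal sketch of the override behavior the patch introduces: the `VLLM_FLASHINFER_WORKSPACE_BUFFER_SIZE` environment variable (name taken from the diff above) replaces the per-backend default workspace size when set. The `_workspace_buffer_size` helper below is hypothetical and only illustrates the parsing pattern; note that in the patch only `flashinfer.py` falls back silently on a malformed value, while the test file and `mla/common.py` let `int()` raise.

```python
# Hypothetical helper illustrating the env-var override pattern from the patch;
# it is not part of the diff itself.
import os


def _workspace_buffer_size(default: int) -> int:
    """Return the FlashInfer workspace size in bytes, honoring the env override."""
    raw = os.environ.get("VLLM_FLASHINFER_WORKSPACE_BUFFER_SIZE", default)
    try:
        return int(raw)
    except (ValueError, TypeError):
        # Mirror flashinfer.py's behavior: fall back to the default on bad input.
        return default


if __name__ == "__main__":
    # e.g. VLLM_FLASHINFER_WORKSPACE_BUFFER_SIZE=536870912 python sketch.py -> 512 MiB
    print(_workspace_buffer_size(256 * 1024 * 1024))
```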