4 changes: 3 additions & 1 deletion tests/kernels/attention/test_flashinfer_mla_decode.py
@@ -1,5 +1,6 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+import os
 import pytest
 import torch
 import torch.nn.functional as F
@@ -8,7 +9,8 @@
 
 from vllm.platforms import current_platform
 
-FLASHINFER_WORKSPACE_BUFFER_SIZE = 128 * 1024 * 1024
+FLASHINFER_WORKSPACE_BUFFER_SIZE = int(
Collaborator:
Please move this to vllm/envs.py

Collaborator:
+1

+    os.environ.get("VLLM_FLASHINFER_WORKSPACE_BUFFER_SIZE", 128 * 1024 * 1024))
 
 if not current_platform.has_device_capability(100):
     pytest.skip(
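
Following the reviewers' suggestion above, here is a minimal sketch of what a `vllm/envs.py` entry could look like. The lambda-in-dict pattern mirrors how `vllm/envs.py` already defines its variables; the helper name `_maybe_int` and the None-when-unset convention (so each backend keeps its own default) are assumptions, not part of this PR:

```python
# Sketch for vllm/envs.py, assuming the existing lambda-per-variable pattern.
# _maybe_int and the None-when-unset behavior are hypothetical: returning
# None lets each backend fall back to its own default (128/256/394 MiB).
import os
from typing import Any, Callable, Optional


def _maybe_int(name: str) -> Optional[int]:
    """Return the env var parsed as an int, or None when it is unset."""
    value = os.getenv(name)
    return int(value) if value is not None else None


environment_variables: dict[str, Callable[[], Any]] = {
    "VLLM_FLASHINFER_WORKSPACE_BUFFER_SIZE":
    lambda: _maybe_int("VLLM_FLASHINFER_WORKSPACE_BUFFER_SIZE"),
}
```

A backend would then read something like `envs.VLLM_FLASHINFER_WORKSPACE_BUFFER_SIZE or 256 * 1024 * 1024` instead of calling `os.environ.get` directly at import time.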
4 changes: 3 additions & 1 deletion vllm/v1/attention/backends/flashinfer.py
@@ -3,6 +3,7 @@
"""Attention layer with FlashInfer."""
from __future__ import annotations

import os
from dataclasses import dataclass
from typing import ClassVar, Optional, Union

@@ -41,7 +42,8 @@
 # yapf: enable
 from vllm.v1.kv_cache_interface import AttentionSpec
 
-FLASHINFER_WORKSPACE_BUFFER_SIZE = 256 * 1024 * 1024
+FLASHINFER_WORKSPACE_BUFFER_SIZE = int(
+    os.environ.get("VLLM_FLASHINFER_WORKSPACE_BUFFER_SIZE", 256 * 1024 * 1024))
 
 FP8_DTYPE = current_platform.fp8_dtype()
 FP4_DTYPE = torch.uint8
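
One thing worth noting about the approach in this diff: the constant is evaluated once at module import time, so the override must be in place before the backend is first imported. A usage sketch (the 512 MiB value is an arbitrary example):

```python
# Usage sketch: set the override before importing the backend, because
# FLASHINFER_WORKSPACE_BUFFER_SIZE is computed when the module is imported.
import os

os.environ["VLLM_FLASHINFER_WORKSPACE_BUFFER_SIZE"] = str(512 * 1024 * 1024)

from vllm.v1.attention.backends import flashinfer  # noqa: E402

assert flashinfer.FLASHINFER_WORKSPACE_BUFFER_SIZE == 512 * 1024 * 1024
```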
4 changes: 3 additions & 1 deletion vllm/v1/attention/backends/mla/common.py
@@ -188,6 +188,7 @@
"""

import functools
import os
from abc import abstractmethod
from dataclasses import dataclass, field
from typing import ClassVar, Generic, Optional, TypeVar, Union
@@ -426,7 +427,8 @@ def use_cudnn_prefill() -> bool:
 # Currently 394MB, this can be tuned based on GEMM sizes used.
 # Chosen to be the same as sglang:
 # https://github.com/sgl-project/sglang/blob/766392c6bda2558b61ce6d1c1bfd8081a549e1f1/python/sglang/global_config.py#L37
-FLASHINFER_WORKSPACE_BUFFER_SIZE = 394 * 1024 * 1024
+FLASHINFER_WORKSPACE_BUFFER_SIZE = int(
+    os.environ.get("VLLM_FLASHINFER_WORKSPACE_BUFFER_SIZE", 394 * 1024 * 1024))
 
 
 class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
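
For context on the tuning comment above: FlashInfer wrappers take a caller-allocated workspace tensor, which is where this byte count ends up. A hedged sketch of the typical allocation, using FlashInfer's documented decode wrapper as an illustration (vLLM's actual call site, and the MLA-specific wrapper it uses, may differ):

```python
# Sketch: handing a workspace buffer of this size to FlashInfer.
# BatchDecodeWithPagedKVCacheWrapper and the "NHD" layout argument follow
# FlashInfer's public API; the constant is restated here for self-containment.
import torch
import flashinfer

FLASHINFER_WORKSPACE_BUFFER_SIZE = 394 * 1024 * 1024  # as in mla/common.py

workspace_buffer = torch.empty(
    FLASHINFER_WORKSPACE_BUFFER_SIZE,  # size in bytes, held as uint8
    dtype=torch.uint8,
    device="cuda",
)
decode_wrapper = flashinfer.BatchDecodeWithPagedKVCacheWrapper(
    workspace_buffer, "NHD")
```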