2 changes: 1 addition & 1 deletion cmake/external_projects/vllm_flash_attn.cmake
@@ -38,7 +38,7 @@ else()
FetchContent_Declare(
vllm-flash-attn
GIT_REPOSITORY https://github.com/vllm-project/flash-attention.git
GIT_TAG 58e0626a692f09241182582659e3bf8f16472659
GIT_TAG 71bb26f6295449be880344b93b51791cc009237d
GIT_PROGRESS TRUE
# Don't share the vllm-flash-attn build between build types
BINARY_DIR ${CMAKE_BINARY_DIR}/vllm-flash-attn
4 changes: 2 additions & 2 deletions tests/kernels/attention/test_flash_attn.py
@@ -13,14 +13,14 @@
)

NUM_HEADS = [(4, 4), (8, 2)]
HEAD_SIZES = [128, 256]
HEAD_SIZES = [40, 72, 80, 128, 256]
BLOCK_SIZES = [16]
DTYPES = [torch.bfloat16]
QDTYPES = [None, torch.float8_e4m3fn]
# one value large enough to test overflow in index calculation.
# one value small enough to test the schema op check
NUM_BLOCKS = [32768, 2048]
SOFT_CAPS = [None, 50.0]
SOFT_CAPS = [None]
SLIDING_WINDOWS = [None, 256]


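The wider HEAD_SIZES list above (now including 40, 72, and 80) and the trimmed SOFT_CAPS list feed pytest parametrization in this test module. As a minimal sketch of how such constants typically fan out into cases — the decorator stacking here is illustrative, not copied from the file — every new head size satisfies the relaxed FlashAttention rule (a multiple of 8, at most 256):

```python
import pytest

# Mirrors the constants from the diff above (illustrative reuse, not the real test).
HEAD_SIZES = [40, 72, 80, 128, 256]
SOFT_CAPS = [None]


@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("soft_cap", SOFT_CAPS)
def test_head_size_is_supported(head_size, soft_cap):
    # Every parametrized head size should pass the new FlashAttention
    # predicate: a multiple of 8 no larger than 256.
    assert head_size % 8 == 0 and head_size <= 256
```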
30 changes: 1 addition & 29 deletions tests/kernels/attention/test_mha_attn.py
@@ -62,38 +62,10 @@ def test_mha_attn_platform(device: str):
assert attn.attn_backend == AttentionBackendEnum.FLASH_ATTN

# Test CUDA with head_size=72 (not divisible by 32)
# - with upstream FA not available
# - should use xformers
with (
patch("vllm.attention.layer.current_platform", CudaPlatform()),
patch("vllm.model_executor.models.vision.current_platform", CudaPlatform()),
patch(
"vllm.attention.layer.check_upstream_fa_availability",
return_value=False,
),
):
attn = MultiHeadAttention(16, 72, scale=1)
assert attn.attn_backend == AttentionBackendEnum.XFORMERS

# Test CUDA with head_size=72 (not divisible by 32)
# - with upstream FA available
# - should use upstream FA
# - should use vLLM's FlashAttention
with (
patch("vllm.attention.layer.current_platform", CudaPlatform()),
patch("vllm.model_executor.models.vision.current_platform", CudaPlatform()),
patch(
"vllm.attention.layer.check_upstream_fa_availability", return_value=True
),
patch.dict(
"sys.modules",
{
"flash_attn": type(
"MockFlashAttn",
(),
{"flash_attn_varlen_func": lambda *args, **kwargs: None},
)()
},
),
):
attn = MultiHeadAttention(16, 72, scale=1)
assert attn.attn_backend == AttentionBackendEnum.FLASH_ATTN
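With the upstream-FA branches removed, the surviving context lines suggest the head_size=72 check collapses to a single case that now expects vLLM's FlashAttention. A sketch of the remaining test body, reconstructed only from the unchanged lines shown above (import paths are inferred from the patch targets; this assumes no further edits are hidden in the collapsed part of the hunk):

```python
from unittest.mock import patch

from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.attention.layer import MultiHeadAttention
from vllm.platforms.cuda import CudaPlatform

# Test CUDA with head_size=72 (not divisible by 32)
# - should use vLLM's FlashAttention
with (
    patch("vllm.attention.layer.current_platform", CudaPlatform()),
    patch("vllm.model_executor.models.vision.current_platform", CudaPlatform()),
):
    attn = MultiHeadAttention(16, 72, scale=1)
    assert attn.attn_backend == AttentionBackendEnum.FLASH_ATTN
```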
21 changes: 9 additions & 12 deletions vllm/platforms/cuda.py
@@ -267,24 +267,21 @@ def get_vit_attn_backend(
) -> "AttentionBackendEnum":
from vllm.attention.backends.registry import AttentionBackendEnum

# For Blackwell GPUs, force TORCH_SDPA for now.
# See https://github.com/facebookresearch/xformers/issues/1317#issuecomment-3199392579 # noqa: E501
if cls.has_device_capability(100):
return AttentionBackendEnum.TORCH_SDPA

if dtype not in (torch.float16, torch.bfloat16):
return AttentionBackendEnum.XFORMERS

if cls.has_device_capability(80):
# Try FlashAttention first
try:
backend_class = AttentionBackendEnum.FLASH_ATTN.get_class()
if backend_class.supports_head_size(
head_size
) and backend_class.supports_dtype(dtype):
return AttentionBackendEnum.FLASH_ATTN
else:
return AttentionBackendEnum.XFORMERS
except ImportError:
pass

if cls.has_device_capability(100):
# xFormers doesn't support Blackwell, fall back to SDPA
# See https://github.com/facebookresearch/xformers/issues/1317#issuecomment-3199392579 # noqa: E501
return AttentionBackendEnum.TORCH_SDPA
else:
# Fallback for Volta/Turing GPUs or FA not supported
return AttentionBackendEnum.XFORMERS

@classmethod
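The reordering above tries FlashAttention first and only afterwards picks between SDPA and xFormers, instead of forcing SDPA on Blackwell up front. A standalone sketch of that decision order — `fa_supported`, `is_blackwell`, and `fp16_like` are hypothetical stand-ins for the capability and support checks, not vLLM APIs:

```python
from enum import Enum, auto


class Backend(Enum):
    FLASH_ATTN = auto()
    TORCH_SDPA = auto()
    XFORMERS = auto()


def pick_vit_backend(fa_supported: bool, is_blackwell: bool, fp16_like: bool) -> Backend:
    """Illustrative mirror of the selection order in the diff above."""
    if not fp16_like:
        # Non-fp16/bf16 dtypes skip the FlashAttention attempt entirely.
        return Backend.XFORMERS
    if fa_supported:
        # FlashAttention wins whenever head size and dtype allow it.
        return Backend.FLASH_ATTN
    if is_blackwell:
        # xFormers doesn't support Blackwell, so fall back to SDPA there.
        return Backend.TORCH_SDPA
    # Volta/Turing GPUs or unsupported FA configurations fall back to xFormers.
    return Backend.XFORMERS


# Example: a supported head size on Blackwell now resolves to FlashAttention
# rather than being forced onto SDPA.
assert pick_vit_backend(True, True, True) is Backend.FLASH_ATTN
```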
4 changes: 2 additions & 2 deletions vllm/v1/attention/backends/flash_attn.py
@@ -119,8 +119,8 @@ def get_fp8_dtype_for_flashattn(kv_cache_dtype: str) -> torch.dtype:
raise ValueError(f"Unrecognized FP8 dtype: {kv_cache_dtype}")

@classmethod
def get_supported_head_sizes(cls) -> list[int]:
return [32, 64, 96, 128, 160, 192, 224, 256]
def supports_head_size(cls, head_size: int) -> bool:
return head_size % 8 == 0 and head_size <= 256

@classmethod
def supports_kv_cache_dtype(cls, kv_cache_dtype: CacheDType | None) -> bool:
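The head-size check above moves from a fixed whitelist to a predicate, which is what lets the new 40/72/80 test sizes through. A quick comparison of the two rules (the old list is copied from the deleted lines; the loop is just a demonstration):

```python
OLD_SUPPORTED_HEAD_SIZES = [32, 64, 96, 128, 160, 192, 224, 256]


def supports_head_size(head_size: int) -> bool:
    # New rule from the diff: any multiple of 8 no larger than 256.
    return head_size % 8 == 0 and head_size <= 256


for hs in (40, 72, 80, 128, 256):
    # 40, 72 and 80 fail the old whitelist but pass the new predicate.
    print(hs, hs in OLD_SUPPORTED_HEAD_SIZES, supports_head_size(hs))
```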