Commit 4f05ea4

MatthewBonanni authored and inkcherry committed
[Attention] FlashAttention ViT support, make default backend (vllm-project#28763)
Signed-off-by: Matthew Bonanni <mbonanni@redhat.com>
Signed-off-by: inkcherry <mingzhi.liu@amd.com>
1 parent 814843e commit 4f05ea4

File tree

5 files changed: 15 additions & 46 deletions


cmake/external_projects/vllm_flash_attn.cmake

Lines changed: 1 addition & 1 deletion
@@ -38,7 +38,7 @@ else()
   FetchContent_Declare(
         vllm-flash-attn
         GIT_REPOSITORY https://github.com/vllm-project/flash-attention.git
-        GIT_TAG 58e0626a692f09241182582659e3bf8f16472659
+        GIT_TAG 71bb26f6295449be880344b93b51791cc009237d
         GIT_PROGRESS TRUE
         # Don't share the vllm-flash-attn build between build types
         BINARY_DIR ${CMAKE_BINARY_DIR}/vllm-flash-attn

tests/kernels/attention/test_flash_attn.py

Lines changed: 2 additions & 2 deletions
@@ -13,14 +13,14 @@
 )
 
 NUM_HEADS = [(4, 4), (8, 2)]
-HEAD_SIZES = [128, 256]
+HEAD_SIZES = [40, 72, 80, 128, 256]
 BLOCK_SIZES = [16]
 DTYPES = [torch.bfloat16]
 QDTYPES = [None, torch.float8_e4m3fn]
 # one value large enough to test overflow in index calculation.
 # one value small enough to test the schema op check
 NUM_BLOCKS = [32768, 2048]
-SOFT_CAPS = [None, 50.0]
+SOFT_CAPS = [None]
 SLIDING_WINDOWS = [None, 256]
 
 
tests/kernels/attention/test_mha_attn.py

Lines changed: 1 addition & 29 deletions
@@ -62,38 +62,10 @@ def test_mha_attn_platform(device: str):
         assert attn.attn_backend == AttentionBackendEnum.FLASH_ATTN
 
     # Test CUDA with head_size=72 (not divisible by 32)
-    # - with upstream FA not available
-    # - should use xformers
-    with (
-        patch("vllm.attention.layer.current_platform", CudaPlatform()),
-        patch("vllm.model_executor.models.vision.current_platform", CudaPlatform()),
-        patch(
-            "vllm.attention.layer.check_upstream_fa_availability",
-            return_value=False,
-        ),
-    ):
-        attn = MultiHeadAttention(16, 72, scale=1)
-        assert attn.attn_backend == AttentionBackendEnum.XFORMERS
-
-    # Test CUDA with head_size=72 (not divisible by 32)
-    # - with upstream FA available
-    # - should use upstream FA
+    # - should use vLLM's FlashAttention
     with (
         patch("vllm.attention.layer.current_platform", CudaPlatform()),
         patch("vllm.model_executor.models.vision.current_platform", CudaPlatform()),
-        patch(
-            "vllm.attention.layer.check_upstream_fa_availability", return_value=True
-        ),
-        patch.dict(
-            "sys.modules",
-            {
-                "flash_attn": type(
-                    "MockFlashAttn",
-                    (),
-                    {"flash_attn_varlen_func": lambda *args, **kwargs: None},
-                )()
-            },
-        ),
     ):
         attn = MultiHeadAttention(16, 72, scale=1)
         assert attn.attn_backend == AttentionBackendEnum.FLASH_ATTN
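As a quick sanity check outside the test harness, one could construct the layer directly and inspect the backend it picked. This is a sketch under the assumption of a CUDA build of vLLM with its FlashAttention package available; it is not part of the commit:

# Sketch only: verify the new default outside of pytest. Assumes a CUDA
# build of vLLM with its FlashAttention package installed.
from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.attention.layer import MultiHeadAttention

attn = MultiHeadAttention(16, 72, scale=1)  # 16 heads, head_size=72
assert attn.attn_backend == AttentionBackendEnum.FLASH_ATTN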

vllm/platforms/cuda.py

Lines changed: 9 additions & 12 deletions
@@ -267,24 +267,21 @@ def get_vit_attn_backend(
     ) -> "AttentionBackendEnum":
         from vllm.attention.backends.registry import AttentionBackendEnum
 
-        # For Blackwell GPUs, force TORCH_SDPA for now.
-        # See https://github.com/facebookresearch/xformers/issues/1317#issuecomment-3199392579 # noqa: E501
-        if cls.has_device_capability(100):
-            return AttentionBackendEnum.TORCH_SDPA
-
-        if dtype not in (torch.float16, torch.bfloat16):
-            return AttentionBackendEnum.XFORMERS
-
-        if cls.has_device_capability(80):
+        # Try FlashAttention first
+        try:
             backend_class = AttentionBackendEnum.FLASH_ATTN.get_class()
             if backend_class.supports_head_size(
                 head_size
             ) and backend_class.supports_dtype(dtype):
                 return AttentionBackendEnum.FLASH_ATTN
-            else:
-                return AttentionBackendEnum.XFORMERS
+        except ImportError:
+            pass
+
+        if cls.has_device_capability(100):
+            # xFormers doesn't support Blackwell, fall back to SDPA
+            # See https://github.com/facebookresearch/xformers/issues/1317#issuecomment-3199392579 # noqa: E501
+            return AttentionBackendEnum.TORCH_SDPA
         else:
-            # Fallback for Volta/Turing GPUs or FA not supported
             return AttentionBackendEnum.XFORMERS
 
     @classmethod
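To make the new fallback order explicit, here is a standalone sketch of the selection logic. It is illustrative only: select_vit_backend, flash_attn_importable, and is_blackwell are hypothetical stand-ins for the real method, the try/except ImportError, and cls.has_device_capability(100):

import torch

def select_vit_backend(
    head_size: int,
    dtype: torch.dtype,
    flash_attn_importable: bool,  # stand-in for the try/except ImportError
    is_blackwell: bool,           # stand-in for cls.has_device_capability(100)
) -> str:
    # 1. Prefer FlashAttention when it imports and supports the head size/dtype.
    if (
        flash_attn_importable
        and head_size % 8 == 0
        and head_size <= 256
        and dtype in (torch.float16, torch.bfloat16)
    ):
        return "FLASH_ATTN"
    # 2. xFormers does not support Blackwell, so fall back to torch SDPA there.
    if is_blackwell:
        return "TORCH_SDPA"
    # 3. Everything else falls back to xFormers.
    return "XFORMERS"

# e.g. head_size=72 with bf16 now selects FlashAttention instead of xFormers:
assert select_vit_backend(72, torch.bfloat16, True, False) == "FLASH_ATTN"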

vllm/v1/attention/backends/flash_attn.py

Lines changed: 2 additions & 2 deletions
@@ -119,8 +119,8 @@ def get_fp8_dtype_for_flashattn(kv_cache_dtype: str) -> torch.dtype:
         raise ValueError(f"Unrecognized FP8 dtype: {kv_cache_dtype}")
 
     @classmethod
-    def get_supported_head_sizes(cls) -> list[int]:
-        return [32, 64, 96, 128, 160, 192, 224, 256]
+    def supports_head_size(cls, head_size: int) -> bool:
+        return head_size % 8 == 0 and head_size <= 256
 
     @classmethod
     def supports_kv_cache_dtype(cls, kv_cache_dtype: CacheDType | None) -> bool:
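A quick illustration of what the relaxed check changes: the old allow-list contained only multiples of 32, while the new rule accepts any multiple of 8 up to 256, which covers ViT head sizes such as 40, 72, and 80:

OLD_SUPPORTED_HEAD_SIZES = {32, 64, 96, 128, 160, 192, 224, 256}

def supports_head_size(head_size: int) -> bool:
    # New rule from this commit: any multiple of 8 up to 256.
    return head_size % 8 == 0 and head_size <= 256

for hs in (40, 72, 80, 128, 256, 260):
    print(hs, hs in OLD_SUPPORTED_HEAD_SIZES, supports_head_size(hs))
# 40, 72, 80 -> newly supported; 128, 256 -> supported by both; 260 -> neither.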
