From 76b70ce618dd610c32fa6bc737eed4618201ee0c Mon Sep 17 00:00:00 2001
From: Matthew Bonanni
Date: Sat, 15 Nov 2025 00:15:30 +0000
Subject: [PATCH 1/9] update head size support

Signed-off-by: Matthew Bonanni
---
 vllm/v1/attention/backends/flash_attn.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/vllm/v1/attention/backends/flash_attn.py b/vllm/v1/attention/backends/flash_attn.py
index a5d4435000d4..fdc99a0df1c8 100755
--- a/vllm/v1/attention/backends/flash_attn.py
+++ b/vllm/v1/attention/backends/flash_attn.py
@@ -119,8 +119,8 @@ def get_fp8_dtype_for_flashattn(kv_cache_dtype: str) -> torch.dtype:
             raise ValueError(f"Unrecognized FP8 dtype: {kv_cache_dtype}")

     @classmethod
-    def get_supported_head_sizes(cls) -> list[int]:
-        return [32, 64, 96, 128, 160, 192, 224, 256]
+    def supports_head_size(cls, head_size: int) -> bool:
+        return head_size % 8 == 0 and head_size <= 256

     @classmethod
     def supports_kv_cache_dtype(cls, kv_cache_dtype: CacheDType | None) -> bool:

From 0001b995fda59101b71d7ac3a2821b1ce2f01473 Mon Sep 17 00:00:00 2001
From: Matthew Bonanni
Date: Sat, 15 Nov 2025 00:15:49 +0000
Subject: [PATCH 2/9] update test

Signed-off-by: Matthew Bonanni
---
 tests/kernels/attention/test_flash_attn.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tests/kernels/attention/test_flash_attn.py b/tests/kernels/attention/test_flash_attn.py
index 6e5468969bf2..ac7cf8ace9a9 100644
--- a/tests/kernels/attention/test_flash_attn.py
+++ b/tests/kernels/attention/test_flash_attn.py
@@ -13,7 +13,7 @@
 )

 NUM_HEADS = [(4, 4), (8, 2)]
-HEAD_SIZES = [128, 256]
+HEAD_SIZES = [40, 72, 80, 128, 256]
 BLOCK_SIZES = [16]
 DTYPES = [torch.bfloat16]
 QDTYPES = [None, torch.float8_e4m3fn]
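
The two patches above replace the fixed whitelist of supported FlashAttention head sizes with a divisibility rule and extend the kernel test to sizes the old list rejected. A standalone sketch of the new rule (the free function and the assertion below are illustrative only; the real check is the classmethod patched above):

    def supports_head_size(head_size: int) -> bool:
        # Mirrors the patched classmethod: any multiple of 8 up to 256.
        return head_size % 8 == 0 and head_size <= 256

    # Every value in the updated HEAD_SIZES list satisfies the predicate,
    # including 40, 72, and 80, which the old whitelist
    # [32, 64, 96, 128, 160, 192, 224, 256] did not contain.
    assert all(supports_head_size(hs) for hs in [40, 72, 80, 128, 256])
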
From 659024411dac3ce112bfaa246d7d309552c94836 Mon Sep 17 00:00:00 2001
From: Matthew Bonanni
Date: Sat, 15 Nov 2025 00:26:18 +0000
Subject: [PATCH 3/9] update selector

Signed-off-by: Matthew Bonanni
---
 vllm/platforms/cuda.py | 24 +++++++++---------------
 1 file changed, 9 insertions(+), 15 deletions(-)

diff --git a/vllm/platforms/cuda.py b/vllm/platforms/cuda.py
index 2e4dd8bb808b..62b601518bb8 100644
--- a/vllm/platforms/cuda.py
+++ b/vllm/platforms/cuda.py
@@ -267,24 +267,18 @@ def get_vit_attn_backend(
     ) -> "AttentionBackendEnum":
         from vllm.attention.backends.registry import AttentionBackendEnum

-        # For Blackwell GPUs, force TORCH_SDPA for now.
-        # See https://github.com/facebookresearch/xformers/issues/1317#issuecomment-3199392579 # noqa: E501
+        # Try FlashAttention first
+        backend_class = AttentionBackendEnum.FLASH_ATTN.get_class()
+        if backend_class.supports_head_size(head_size) and backend_class.supports_dtype(
+            dtype
+        ):
+            return AttentionBackendEnum.FLASH_ATTN
+
         if cls.has_device_capability(100):
+            # xFormers doesn't support Blackwell, fall back to SDPA
+            # See https://github.com/facebookresearch/xformers/issues/1317#issuecomment-3199392579 # noqa: E501
             return AttentionBackendEnum.TORCH_SDPA
-
-        if dtype not in (torch.float16, torch.bfloat16):
-            return AttentionBackendEnum.XFORMERS
-
-        if cls.has_device_capability(80):
-            backend_class = AttentionBackendEnum.FLASH_ATTN.get_class()
-            if backend_class.supports_head_size(
-                head_size
-            ) and backend_class.supports_dtype(dtype):
-                return AttentionBackendEnum.FLASH_ATTN
-            else:
-                return AttentionBackendEnum.XFORMERS
         else:
-            # Fallback for Volta/Turing GPUs or FA not supported
             return AttentionBackendEnum.XFORMERS

     @classmethod

From 9e96e6fb621595f161fac6946da8c6dcb55ec170 Mon Sep 17 00:00:00 2001
From: Matthew Bonanni
Date: Sat, 15 Nov 2025 00:28:47 +0000
Subject: [PATCH 4/9] fix capability

Signed-off-by: Matthew Bonanni
---
 vllm/platforms/cuda.py | 6 ++++--
 1 file changed, 4 insertions(+), 2 deletions(-)

diff --git a/vllm/platforms/cuda.py b/vllm/platforms/cuda.py
index 62b601518bb8..844d6a6f31b5 100644
--- a/vllm/platforms/cuda.py
+++ b/vllm/platforms/cuda.py
@@ -269,8 +269,10 @@ def get_vit_attn_backend(

         # Try FlashAttention first
         backend_class = AttentionBackendEnum.FLASH_ATTN.get_class()
-        if backend_class.supports_head_size(head_size) and backend_class.supports_dtype(
-            dtype
+        if (
+            backend_class.supports_head_size(head_size)
+            and backend_class.supports_dtype(dtype)
+            and cls.has_device_capability(80)
         ):
             return AttentionBackendEnum.FLASH_ATTN

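
Patch 4 gates the FlashAttention path on cls.has_device_capability(80), while the Blackwell fallback added in patch 3 keys off capability 100. A minimal sketch of that comparison, assuming the usual major * 10 + minor encoding of CUDA compute capability (the helper below is illustrative, not vLLM's implementation):

    def has_device_capability(threshold: int, major: int, minor: int) -> bool:
        # Ampere A100 is (8, 0) -> 80, Hopper H100 is (9, 0) -> 90,
        # Blackwell B200 is (10, 0) -> 100.
        return major * 10 + minor >= threshold

    assert has_device_capability(80, 9, 0)       # Hopper may use FlashAttention
    assert not has_device_capability(100, 9, 0)  # ...and is not treated as Blackwell
    assert has_device_capability(100, 10, 0)     # Blackwell falls back to TORCH_SDPA
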
From 1d35c460cd227c599a40a83ae37cdd42ae6da5b6 Mon Sep 17 00:00:00 2001
From: Matthew Bonanni
Date: Sat, 15 Nov 2025 00:30:09 +0000
Subject: [PATCH 5/9] temporary FA hash

Signed-off-by: Matthew Bonanni
---
 cmake/external_projects/vllm_flash_attn.cmake | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/cmake/external_projects/vllm_flash_attn.cmake b/cmake/external_projects/vllm_flash_attn.cmake
index 567c8959f045..f8904b8655dd 100644
--- a/cmake/external_projects/vllm_flash_attn.cmake
+++ b/cmake/external_projects/vllm_flash_attn.cmake
@@ -37,8 +37,8 @@ if(VLLM_FLASH_ATTN_SRC_DIR)
 else()
   FetchContent_Declare(
         vllm-flash-attn
-        GIT_REPOSITORY https://github.com/vllm-project/flash-attention.git
-        GIT_TAG 58e0626a692f09241182582659e3bf8f16472659
+        GIT_REPOSITORY https://github.com/MatthewBonanni/flash-attention.git
+        GIT_TAG ec19a62162ccbd94301f9f32ea79504b6274abeb
         GIT_PROGRESS TRUE
         # Don't share the vllm-flash-attn build between build types
         BINARY_DIR ${CMAKE_BINARY_DIR}/vllm-flash-attn

From c5d997d2e1b6b8b5fa24ff68c47397a3ad11d5aa Mon Sep 17 00:00:00 2001
From: Matthew Bonanni
Date: Sat, 15 Nov 2025 00:34:21 +0000
Subject: [PATCH 6/9] current FA build doesn't support tanh softcapping

Signed-off-by: Matthew Bonanni
---
 tests/kernels/attention/test_flash_attn.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tests/kernels/attention/test_flash_attn.py b/tests/kernels/attention/test_flash_attn.py
index ac7cf8ace9a9..26b8c77ab482 100644
--- a/tests/kernels/attention/test_flash_attn.py
+++ b/tests/kernels/attention/test_flash_attn.py
@@ -20,7 +20,7 @@
 # one value large enough to test overflow in index calculation.
 # one value small enough to test the schema op check
 NUM_BLOCKS = [32768, 2048]
-SOFT_CAPS = [None, 50.0]
+SOFT_CAPS = [None]
 SLIDING_WINDOWS = [None, 256]

From 0c871f456bb4c102f8768ce1d1b55c8741fcab8e Mon Sep 17 00:00:00 2001
From: Matthew Bonanni
Date: Sat, 15 Nov 2025 00:41:11 +0000
Subject: [PATCH 7/9] handle import error

Signed-off-by: Matthew Bonanni
---
 vllm/platforms/cuda.py | 15 ++++++++-------
 1 file changed, 8 insertions(+), 7 deletions(-)

diff --git a/vllm/platforms/cuda.py b/vllm/platforms/cuda.py
index 844d6a6f31b5..f9bf242b7194 100644
--- a/vllm/platforms/cuda.py
+++ b/vllm/platforms/cuda.py
@@ -268,13 +268,14 @@ def get_vit_attn_backend(
         from vllm.attention.backends.registry import AttentionBackendEnum

         # Try FlashAttention first
-        backend_class = AttentionBackendEnum.FLASH_ATTN.get_class()
-        if (
-            backend_class.supports_head_size(head_size)
-            and backend_class.supports_dtype(dtype)
-            and cls.has_device_capability(80)
-        ):
-            return AttentionBackendEnum.FLASH_ATTN
+        try:
+            backend_class = AttentionBackendEnum.FLASH_ATTN.get_class()
+            if backend_class.supports_head_size(
+                head_size
+            ) and backend_class.supports_dtype(dtype):
+                return AttentionBackendEnum.FLASH_ATTN
+        except ImportError:
+            pass

         if cls.has_device_capability(100):
             # xFormers doesn't support Blackwell, fall back to SDPA
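
After patch 7 the ViT attention backend selection reduces to: prefer FlashAttention whenever its backend class can be imported and accepts the head size and dtype, otherwise fall back to TORCH_SDPA on Blackwell (capability 100) and to xFormers everywhere else. A self-contained sketch of that decision order (stand-in enum and boolean inputs; the real code is CudaPlatform.get_vit_attn_backend in vllm/platforms/cuda.py):

    from enum import Enum, auto

    class Backend(Enum):
        FLASH_ATTN = auto()
        TORCH_SDPA = auto()
        XFORMERS = auto()

    def select_vit_backend(
        fa_importable: bool,
        fa_supports_head_size: bool,
        fa_supports_dtype: bool,
        is_blackwell: bool,
    ) -> Backend:
        # 1. Try FlashAttention first; an ImportError simply skips this branch.
        if fa_importable and fa_supports_head_size and fa_supports_dtype:
            return Backend.FLASH_ATTN
        # 2. xFormers doesn't support Blackwell, so fall back to SDPA there.
        if is_blackwell:
            return Backend.TORCH_SDPA
        # 3. Everything else falls back to xFormers.
        return Backend.XFORMERS

    assert select_vit_backend(True, True, True, False) is Backend.FLASH_ATTN
    assert select_vit_backend(False, True, True, True) is Backend.TORCH_SDPA
    assert select_vit_backend(False, True, True, False) is Backend.XFORMERS
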
From 25c83d4f9a7f49d6ee8dfc9ae21524c1a315454a Mon Sep 17 00:00:00 2001
From: Matthew Bonanni
Date: Mon, 17 Nov 2025 15:33:06 +0000
Subject: [PATCH 8/9] update test

Signed-off-by: Matthew Bonanni
---
 tests/kernels/attention/test_mha_attn.py | 30 +-----------------------------
 1 file changed, 1 insertion(+), 29 deletions(-)

diff --git a/tests/kernels/attention/test_mha_attn.py b/tests/kernels/attention/test_mha_attn.py
index 183bbf3bf4e0..a878ac6396ce 100644
--- a/tests/kernels/attention/test_mha_attn.py
+++ b/tests/kernels/attention/test_mha_attn.py
@@ -62,38 +62,10 @@ def test_mha_attn_platform(device: str):
         assert attn.attn_backend == AttentionBackendEnum.FLASH_ATTN

     # Test CUDA with head_size=72 (not divisible by 32)
-    # - with upstream FA not available
-    # - should use xformers
-    with (
-        patch("vllm.attention.layer.current_platform", CudaPlatform()),
-        patch("vllm.model_executor.models.vision.current_platform", CudaPlatform()),
-        patch(
-            "vllm.attention.layer.check_upstream_fa_availability",
-            return_value=False,
-        ),
-    ):
-        attn = MultiHeadAttention(16, 72, scale=1)
-        assert attn.attn_backend == AttentionBackendEnum.XFORMERS
-
-    # Test CUDA with head_size=72 (not divisible by 32)
-    # - with upstream FA available
-    # - should use upstream FA
+    # - should use vLLM's FlashAttention
     with (
         patch("vllm.attention.layer.current_platform", CudaPlatform()),
         patch("vllm.model_executor.models.vision.current_platform", CudaPlatform()),
-        patch(
-            "vllm.attention.layer.check_upstream_fa_availability", return_value=True
-        ),
-        patch.dict(
-            "sys.modules",
-            {
-                "flash_attn": type(
-                    "MockFlashAttn",
-                    (),
-                    {"flash_attn_varlen_func": lambda *args, **kwargs: None},
-                )()
-            },
-        ),
     ):
         attn = MultiHeadAttention(16, 72, scale=1)
         assert attn.attn_backend == AttentionBackendEnum.FLASH_ATTN

From c4edbac3bc9496ea7e186c7ea45abb86416ece3d Mon Sep 17 00:00:00 2001
From: Matthew Bonanni
Date: Mon, 17 Nov 2025 18:22:15 +0000
Subject: [PATCH 9/9] update tag post-merge

Signed-off-by: Matthew Bonanni
---
 cmake/external_projects/vllm_flash_attn.cmake | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/cmake/external_projects/vllm_flash_attn.cmake b/cmake/external_projects/vllm_flash_attn.cmake
index f8904b8655dd..6cc5cda14c52 100644
--- a/cmake/external_projects/vllm_flash_attn.cmake
+++ b/cmake/external_projects/vllm_flash_attn.cmake
@@ -37,8 +37,8 @@ if(VLLM_FLASH_ATTN_SRC_DIR)
 else()
   FetchContent_Declare(
         vllm-flash-attn
-        GIT_REPOSITORY https://github.com/MatthewBonanni/flash-attention.git
-        GIT_TAG ec19a62162ccbd94301f9f32ea79504b6274abeb
+        GIT_REPOSITORY https://github.com/vllm-project/flash-attention.git
+        GIT_TAG 71bb26f6295449be880344b93b51791cc009237d
         GIT_PROGRESS TRUE
         # Don't share the vllm-flash-attn build between build types
         BINARY_DIR ${CMAKE_BINARY_DIR}/vllm-flash-attn