diff --git a/cmake/external_projects/vllm_flash_attn.cmake b/cmake/external_projects/vllm_flash_attn.cmake
index 567c8959f045..6cc5cda14c52 100644
--- a/cmake/external_projects/vllm_flash_attn.cmake
+++ b/cmake/external_projects/vllm_flash_attn.cmake
@@ -38,7 +38,7 @@ else()
   FetchContent_Declare(
         vllm-flash-attn
         GIT_REPOSITORY https://github.com/vllm-project/flash-attention.git
-        GIT_TAG 58e0626a692f09241182582659e3bf8f16472659
+        GIT_TAG 71bb26f6295449be880344b93b51791cc009237d
         GIT_PROGRESS TRUE
         # Don't share the vllm-flash-attn build between build types
         BINARY_DIR ${CMAKE_BINARY_DIR}/vllm-flash-attn
diff --git a/tests/kernels/attention/test_flash_attn.py b/tests/kernels/attention/test_flash_attn.py
index 6e5468969bf2..26b8c77ab482 100644
--- a/tests/kernels/attention/test_flash_attn.py
+++ b/tests/kernels/attention/test_flash_attn.py
@@ -13,14 +13,14 @@
 )
 
 NUM_HEADS = [(4, 4), (8, 2)]
-HEAD_SIZES = [128, 256]
+HEAD_SIZES = [40, 72, 80, 128, 256]
 BLOCK_SIZES = [16]
 DTYPES = [torch.bfloat16]
 QDTYPES = [None, torch.float8_e4m3fn]
 # one value large enough to test overflow in index calculation.
 # one value small enough to test the schema op check
 NUM_BLOCKS = [32768, 2048]
-SOFT_CAPS = [None, 50.0]
+SOFT_CAPS = [None]
 SLIDING_WINDOWS = [None, 256]
diff --git a/tests/kernels/attention/test_mha_attn.py b/tests/kernels/attention/test_mha_attn.py
index 183bbf3bf4e0..a878ac6396ce 100644
--- a/tests/kernels/attention/test_mha_attn.py
+++ b/tests/kernels/attention/test_mha_attn.py
@@ -62,38 +62,10 @@ def test_mha_attn_platform(device: str):
         assert attn.attn_backend == AttentionBackendEnum.FLASH_ATTN
 
     # Test CUDA with head_size=72 (not divisible by 32)
-    # - with upstream FA not available
-    # - should use xformers
-    with (
-        patch("vllm.attention.layer.current_platform", CudaPlatform()),
-        patch("vllm.model_executor.models.vision.current_platform", CudaPlatform()),
-        patch(
-            "vllm.attention.layer.check_upstream_fa_availability",
-            return_value=False,
-        ),
-    ):
-        attn = MultiHeadAttention(16, 72, scale=1)
-        assert attn.attn_backend == AttentionBackendEnum.XFORMERS
-
-    # Test CUDA with head_size=72 (not divisible by 32)
-    # - with upstream FA available
-    # - should use upstream FA
+    # - should use vLLM's FlashAttention
     with (
         patch("vllm.attention.layer.current_platform", CudaPlatform()),
         patch("vllm.model_executor.models.vision.current_platform", CudaPlatform()),
-        patch(
-            "vllm.attention.layer.check_upstream_fa_availability", return_value=True
-        ),
-        patch.dict(
-            "sys.modules",
-            {
-                "flash_attn": type(
-                    "MockFlashAttn",
-                    (),
-                    {"flash_attn_varlen_func": lambda *args, **kwargs: None},
-                )()
-            },
-        ),
     ):
         attn = MultiHeadAttention(16, 72, scale=1)
         assert attn.attn_backend == AttentionBackendEnum.FLASH_ATTN
diff --git a/vllm/platforms/cuda.py b/vllm/platforms/cuda.py
index 2e4dd8bb808b..f9bf242b7194 100644
--- a/vllm/platforms/cuda.py
+++ b/vllm/platforms/cuda.py
@@ -267,24 +267,21 @@ def get_vit_attn_backend(
     ) -> "AttentionBackendEnum":
         from vllm.attention.backends.registry import AttentionBackendEnum
 
-        # For Blackwell GPUs, force TORCH_SDPA for now.
-        # See https://github.com/facebookresearch/xformers/issues/1317#issuecomment-3199392579 # noqa: E501
-        if cls.has_device_capability(100):
-            return AttentionBackendEnum.TORCH_SDPA
-
-        if dtype not in (torch.float16, torch.bfloat16):
-            return AttentionBackendEnum.XFORMERS
-
-        if cls.has_device_capability(80):
+        # Try FlashAttention first
+        try:
             backend_class = AttentionBackendEnum.FLASH_ATTN.get_class()
             if backend_class.supports_head_size(
                 head_size
             ) and backend_class.supports_dtype(dtype):
                 return AttentionBackendEnum.FLASH_ATTN
-            else:
-                return AttentionBackendEnum.XFORMERS
+        except ImportError:
+            pass
+
+        if cls.has_device_capability(100):
+            # xFormers doesn't support Blackwell, fall back to SDPA
+            # See https://github.com/facebookresearch/xformers/issues/1317#issuecomment-3199392579 # noqa: E501
+            return AttentionBackendEnum.TORCH_SDPA
         else:
-            # Fallback for Volta/Turing GPUs or FA not supported
             return AttentionBackendEnum.XFORMERS
 
     @classmethod
diff --git a/vllm/v1/attention/backends/flash_attn.py b/vllm/v1/attention/backends/flash_attn.py
index a5d4435000d4..fdc99a0df1c8 100755
--- a/vllm/v1/attention/backends/flash_attn.py
+++ b/vllm/v1/attention/backends/flash_attn.py
@@ -119,8 +119,8 @@ def get_fp8_dtype_for_flashattn(kv_cache_dtype: str) -> torch.dtype:
         raise ValueError(f"Unrecognized FP8 dtype: {kv_cache_dtype}")
 
     @classmethod
-    def get_supported_head_sizes(cls) -> list[int]:
-        return [32, 64, 96, 128, 160, 192, 224, 256]
+    def supports_head_size(cls, head_size: int) -> bool:
+        return head_size % 8 == 0 and head_size <= 256
 
     @classmethod
     def supports_kv_cache_dtype(cls, kv_cache_dtype: CacheDType | None) -> bool: