Remove upstream fa checks #29471
Changes from all commits
```diff
@@ -56,53 +56,28 @@
 logger = init_logger(__name__)


-def check_upstream_fa_availability(dtype: torch.dtype):
-    if (
-        dtype in (torch.float16, torch.bfloat16)
-        and current_platform.is_cuda()
-        and current_platform.has_device_capability(80)
-    ):
-        from transformers.utils import is_flash_attn_2_available
-
-        return is_flash_attn_2_available()
-    if current_platform.is_rocm():
-        from importlib.util import find_spec
-
-        return find_spec("flash_attn") is not None
-    return False
-
-
 def maybe_get_vit_flash_attn_backend(
     attn_backend: AttentionBackendEnum,
-    use_upstream_fa: bool,
     attn_backend_override: AttentionBackendEnum | None = None,
 ) -> tuple[AttentionBackendEnum, Callable | None]:
     if current_platform.is_rocm():
         if envs.VLLM_ROCM_USE_AITER and envs.VLLM_ROCM_USE_AITER_MHA and on_gfx9():
             attn_backend = AttentionBackendEnum.ROCM_AITER_FA
         elif (
-            check_upstream_fa_availability(torch.get_default_dtype())
+            attn_backend_override is None
             and on_gfx9()
-            and attn_backend_override is None
+            and attn_backend == AttentionBackendEnum.FLASH_ATTN
         ):
-            attn_backend = AttentionBackendEnum.FLASH_ATTN
-            use_upstream_fa = True
+            pass
         else:
             return AttentionBackendEnum.TORCH_SDPA, None
     elif current_platform.is_cuda():
-        if (
-            attn_backend != AttentionBackendEnum.FLASH_ATTN
-            and check_upstream_fa_availability(torch.get_default_dtype())
-        ):
-            attn_backend = AttentionBackendEnum.FLASH_ATTN
-            use_upstream_fa = True
+        pass
     elif current_platform.is_xpu():
         assert attn_backend == AttentionBackendEnum.FLASH_ATTN, (
             "XPU platform only supports FLASH_ATTN as vision attention backend."
         )
-        use_upstream_fa = False
+        pass
     else:
         return AttentionBackendEnum.TORCH_SDPA, None
```
```diff
@@ -113,10 +88,7 @@ def maybe_get_vit_flash_attn_backend(
         if attn_backend == AttentionBackendEnum.ROCM_AITER_FA:
             from aiter import flash_attn_varlen_func
         else:
-            if use_upstream_fa:
-                from flash_attn import flash_attn_varlen_func
-            else:
-                from vllm.attention.utils.fa_utils import flash_attn_varlen_func
+            from vllm.attention.utils.fa_utils import flash_attn_varlen_func
```
Collaborator: We can keep the import statement.

Contributor (author): I added this import to `fa_utils` since that looked like the simplest way to do it. The `except` message tells the user to install upstream flash-attn when an `ImportError` is raised. Please check whether that works, thanks!
```diff
     else:
         flash_attn_varlen_func = None
```
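Based on the author's reply above, the fallback now lives inside `fa_utils`. A minimal sketch of what such a fallback import could look like (hypothetical; the actual contents of `vllm/attention/utils/fa_utils.py` and the bundled-kernel module path are assumptions):

```python
# Hypothetical sketch, not the PR's code: prefer vLLM's bundled flash-attention
# kernels, fall back to the upstream flash-attn package, and raise an
# actionable error if neither is importable.
try:
    from vllm.vllm_flash_attn import flash_attn_varlen_func  # assumed bundled-kernel path
except ImportError:
    try:
        from flash_attn import flash_attn_varlen_func  # upstream flash-attn package
    except ImportError as err:
        raise ImportError(
            "No flash-attention implementation is available; install the "
            "upstream package (`pip install flash-attn`) to use this backend."
        ) from err
```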
```diff
@@ -501,11 +473,6 @@ def __init__(
             attn_backend_override=attn_backend_override,
         )

-        # Some auto-selected backends can be upgraded
-        # to upstream flash attention if available.
-        # If vllm native fa is selected, we use it directly.
-        use_upstream_fa = False
-
         self.attn_backend = (
             backend
             if backend
```
```diff
@@ -521,7 +488,6 @@ def __init__(
         self.attn_backend, self._flash_attn_varlen_func = (
             maybe_get_vit_flash_attn_backend(
                 self.attn_backend,
-                use_upstream_fa,
                 attn_backend_override=attn_backend_override,
             )
         )
```
```diff
@@ -531,17 +497,8 @@
             AttentionBackendEnum.ROCM_AITER_FA,
         }

-        # this condition is just to make sure that the
-        # use_upstream_fa in the log is correct
-        if (
-            current_platform.is_rocm()
-            and self.attn_backend == AttentionBackendEnum.FLASH_ATTN
-        ):
-            use_upstream_fa = True
-
         logger.info_once(
-            f"MultiHeadAttention attn_backend: {self.attn_backend}, "
-            f"use_upstream_fa: {use_upstream_fa}"
+            f"Using {self.attn_backend} for MultiHeadAttention in multimodal encoder."
         )

     def forward(
```
Second changed file:
```diff
@@ -27,15 +27,11 @@ def flash_attn_maxseqlen_wrapper(
     max_seqlen: torch.Tensor,
     batch_size: int,
     is_rocm_aiter: bool,
-    use_upstream_fa: bool,
 ) -> torch.Tensor:
     if is_rocm_aiter:
         from aiter import flash_attn_varlen_func
     else:
-        if use_upstream_fa:
-            from flash_attn import flash_attn_varlen_func
-        else:
-            from vllm.attention.utils.fa_utils import flash_attn_varlen_func
+        from vllm.attention.utils.fa_utils import flash_attn_varlen_func
```
Collaborator: Likewise, we can keep the import statement.
```diff
     q, k, v = (einops.rearrange(x, "b s ... -> (b s) ...") for x in [q, k, v])
     output = flash_attn_varlen_func(
         q,
```
```diff
@@ -62,7 +58,6 @@ def flash_attn_maxseqlen_wrapper_fake(
     max_seqlen: torch.Tensor,
     batch_size: int,
     is_rocm_aiter: bool,
-    use_upstream_fa: bool,
 ) -> torch.Tensor:
     b, s, h, d = q.shape
     return torch.empty((s, b, h * d), dtype=q.dtype, device=q.device)
```
```diff
@@ -83,10 +78,9 @@ def vit_flash_attn_wrapper(
     max_seqlen: torch.Tensor,
     batch_size: int,
     is_rocm_aiter: bool,
-    use_upstream_fa: bool,
 ) -> torch.Tensor:
     return torch.ops.vllm.flash_attn_maxseqlen_wrapper(
-        q, k, v, cu_seqlens, max_seqlen, batch_size, is_rocm_aiter, use_upstream_fa
+        q, k, v, cu_seqlens, max_seqlen, batch_size, is_rocm_aiter
     )
```
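For reference, a hedged usage sketch of the simplified wrapper signature after this change. The import path, tensor shapes, and dtypes here are assumptions (shapes follow the fake op above), and a CUDA or ROCm device with the corresponding kernels is required:

```python
import torch

# Assumed import path; adjust to wherever vit_flash_attn_wrapper is defined.
from vllm.model_executor.models.vision import vit_flash_attn_wrapper

batch, seq, heads, head_dim = 2, 16, 8, 64
q = torch.randn(batch, seq, heads, head_dim, dtype=torch.bfloat16, device="cuda")
k = torch.randn_like(q)
v = torch.randn_like(q)
# Cumulative sequence lengths for the flattened (batch * seq) token layout.
cu_seqlens = torch.arange(0, (batch + 1) * seq, seq, dtype=torch.int32, device="cuda")
max_seqlen = torch.tensor(seq, dtype=torch.int32, device="cuda")

out = vit_flash_attn_wrapper(
    q, k, v, cu_seqlens, max_seqlen,
    batch_size=batch,
    is_rocm_aiter=False,  # the use_upstream_fa flag no longer exists after this PR
)
```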
Review comment: `maybe_get_vit_flash_attn_backend` now returns `TORCH_SDPA` for any platform that is neither ROCm nor XPU. CUDA is caught by this `else` branch, so even when `get_vit_attn_backend` selects `FLASH_ATTN`, the function forces Torch SDPA and never loads the flash attention kernel. This effectively disables flash attention for all vision models on CUDA, degrading the intended fast path everywhere and ignoring user overrides.
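One hypothetical way to sanity-check this concern on a CUDA machine (not part of the PR; the import paths below are assumptions and may differ between vLLM versions):

```python
# Hypothetical check: see which backend the helper actually returns on CUDA
# after this change. TORCH_SDPA with None would support the concern above;
# FLASH_ATTN with a callable would refute it.
import torch

from vllm.attention.backends.registry import AttentionBackendEnum  # assumed path
from vllm.attention.layer import maybe_get_vit_flash_attn_backend  # assumed path

torch.set_default_dtype(torch.bfloat16)
backend, varlen_func = maybe_get_vit_flash_attn_backend(AttentionBackendEnum.FLASH_ATTN)
print(backend, varlen_func)
```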