 logger = init_logger(__name__)


-def check_upstream_fa_availability(dtype: torch.dtype):
-    if (
-        dtype in (torch.float16, torch.bfloat16)
-        and current_platform.is_cuda()
-        and current_platform.has_device_capability(80)
-    ):
-        from transformers.utils import is_flash_attn_2_available
-
-        return is_flash_attn_2_available()
-    if current_platform.is_rocm():
-        from importlib.util import find_spec
-
-        return find_spec("flash_attn") is not None
-    return False
-
-
 def maybe_get_vit_flash_attn_backend(
     attn_backend: AttentionBackendEnum,
-    use_upstream_fa: bool,
     attn_backend_override: AttentionBackendEnum | None = None,
 ) -> tuple[AttentionBackendEnum, Callable | None]:
     if current_platform.is_rocm():
         if envs.VLLM_ROCM_USE_AITER and envs.VLLM_ROCM_USE_AITER_MHA and on_gfx9():
             attn_backend = AttentionBackendEnum.ROCM_AITER_FA
-
         elif (
-            check_upstream_fa_availability(torch.get_default_dtype())
+            attn_backend_override is None
             and on_gfx9()
-            and attn_backend_override is None
+            and attn_backend == AttentionBackendEnum.FLASH_ATTN
         ):
-            attn_backend = AttentionBackendEnum.FLASH_ATTN
-            use_upstream_fa = True
+            pass
         else:
             return AttentionBackendEnum.TORCH_SDPA, None
-
     elif current_platform.is_cuda():
-        if (
-            attn_backend != AttentionBackendEnum.FLASH_ATTN
-            and check_upstream_fa_availability(torch.get_default_dtype())
-        ):
-            attn_backend = AttentionBackendEnum.FLASH_ATTN
-            use_upstream_fa = True
+        pass
     elif current_platform.is_xpu():
         assert attn_backend == AttentionBackendEnum.FLASH_ATTN, (
            "XPU platform only supports FLASH_ATTN as vision attention backend."
         )
-        use_upstream_fa = False
+        pass
     else:
         return AttentionBackendEnum.TORCH_SDPA, None

@@ -113,10 +88,7 @@ def maybe_get_vit_flash_attn_backend(
         if attn_backend == AttentionBackendEnum.ROCM_AITER_FA:
             from aiter import flash_attn_varlen_func
         else:
-            if use_upstream_fa:
-                from flash_attn import flash_attn_varlen_func
-            else:
-                from vllm.attention.utils.fa_utils import flash_attn_varlen_func
+            from vllm.attention.utils.fa_utils import flash_attn_varlen_func
     else:
         flash_attn_varlen_func = None

@@ -501,11 +473,6 @@ def __init__(
             attn_backend_override=attn_backend_override,
         )

-        # Some auto-selected backends can be upgraded
-        # to upstream flash attention if available.
-        # If vllm native fa is selected, we use it directly.
-        use_upstream_fa = False
-
         self.attn_backend = (
             backend
             if backend
@@ -521,7 +488,6 @@ def __init__(
         self.attn_backend, self._flash_attn_varlen_func = (
             maybe_get_vit_flash_attn_backend(
                 self.attn_backend,
-                use_upstream_fa,
                 attn_backend_override=attn_backend_override,
             )
         )
@@ -531,17 +497,8 @@ def __init__(
             AttentionBackendEnum.ROCM_AITER_FA,
         }

-        # this condition is just to make sure that the
-        # use_upstream_fa in the log is correct
-        if (
-            current_platform.is_rocm()
-            and self.attn_backend == AttentionBackendEnum.FLASH_ATTN
-        ):
-            use_upstream_fa = True
-
         logger.info_once(
-            f"MultiHeadAttention attn_backend: {self.attn_backend}, "
-            f"use_upstream_fa: {use_upstream_fa}"
+            f"Using {self.attn_backend} for MultiHeadAttention in multimodal encoder."
         )

     def forward(
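
For reference, a minimal sketch of how callers use the helper once use_upstream_fa is gone. The names are taken from the hunks above; the sketch assumes it runs inside the same module (so maybe_get_vit_flash_attn_backend and AttentionBackendEnum are already in scope) and that the requested backend is FLASH_ATTN:

    # Sketch only: resolve the ViT attention backend with the new two-argument
    # signature. The returned callable is None whenever the resolved backend is
    # not a flash-attention variant (see the final hunk of the helper above).
    attn_backend, flash_attn_varlen_func = maybe_get_vit_flash_attn_backend(
        AttentionBackendEnum.FLASH_ATTN,
        attn_backend_override=None,
    )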