diff --git a/vllm/attention/layer.py b/vllm/attention/layer.py
index 62ac38751aa0..da5a62617129 100644
--- a/vllm/attention/layer.py
+++ b/vllm/attention/layer.py
@@ -56,53 +56,28 @@
 logger = init_logger(__name__)
 
 
-def check_upstream_fa_availability(dtype: torch.dtype):
-    if (
-        dtype in (torch.float16, torch.bfloat16)
-        and current_platform.is_cuda()
-        and current_platform.has_device_capability(80)
-    ):
-        from transformers.utils import is_flash_attn_2_available
-
-        return is_flash_attn_2_available()
-    if current_platform.is_rocm():
-        from importlib.util import find_spec
-
-        return find_spec("flash_attn") is not None
-    return False
-
-
 def maybe_get_vit_flash_attn_backend(
     attn_backend: AttentionBackendEnum,
-    use_upstream_fa: bool,
     attn_backend_override: AttentionBackendEnum | None = None,
 ) -> tuple[AttentionBackendEnum, Callable | None]:
     if current_platform.is_rocm():
         if envs.VLLM_ROCM_USE_AITER and envs.VLLM_ROCM_USE_AITER_MHA and on_gfx9():
             attn_backend = AttentionBackendEnum.ROCM_AITER_FA
         elif (
-            check_upstream_fa_availability(torch.get_default_dtype())
+            attn_backend_override is None
             and on_gfx9()
-            and attn_backend_override is None
+            and attn_backend == AttentionBackendEnum.FLASH_ATTN
         ):
-            attn_backend = AttentionBackendEnum.FLASH_ATTN
-            use_upstream_fa = True
+            pass
         else:
             return AttentionBackendEnum.TORCH_SDPA, None
     elif current_platform.is_cuda():
-        if (
-            attn_backend != AttentionBackendEnum.FLASH_ATTN
-            and check_upstream_fa_availability(torch.get_default_dtype())
-        ):
-            attn_backend = AttentionBackendEnum.FLASH_ATTN
-            use_upstream_fa = True
+        pass
     elif current_platform.is_xpu():
         assert attn_backend == AttentionBackendEnum.FLASH_ATTN, (
             "XPU platform only supports FLASH_ATTN as vision attention backend."
         )
-        use_upstream_fa = False
+        pass
     else:
         return AttentionBackendEnum.TORCH_SDPA, None
@@ -113,10 +88,7 @@ def maybe_get_vit_flash_attn_backend(
         if attn_backend == AttentionBackendEnum.ROCM_AITER_FA:
             from aiter import flash_attn_varlen_func
         else:
-            if use_upstream_fa:
-                from flash_attn import flash_attn_varlen_func
-            else:
-                from vllm.attention.utils.fa_utils import flash_attn_varlen_func
+            from vllm.attention.utils.fa_utils import flash_attn_varlen_func
     else:
         flash_attn_varlen_func = None
@@ -501,11 +473,6 @@ def __init__(
             attn_backend_override=attn_backend_override,
         )
 
-        # Some auto-selected backends can be upgraded
-        # to upstream flash attention if available.
-        # If vllm native fa is selected, we use it directly.
-        use_upstream_fa = False
-
         self.attn_backend = (
             backend
             if backend
@@ -521,7 +488,6 @@ def __init__(
         self.attn_backend, self._flash_attn_varlen_func = (
             maybe_get_vit_flash_attn_backend(
                 self.attn_backend,
-                use_upstream_fa,
                 attn_backend_override=attn_backend_override,
             )
         )
@@ -531,17 +497,8 @@ def __init__(
             AttentionBackendEnum.ROCM_AITER_FA,
         }
 
-        # this condition is just to make sure that the
-        # use_upstream_fa in the log is correct
-        if (
-            current_platform.is_rocm()
-            and self.attn_backend == AttentionBackendEnum.FLASH_ATTN
-        ):
-            use_upstream_fa = True
-
         logger.info_once(
-            f"MultiHeadAttention attn_backend: {self.attn_backend}, "
-            f"use_upstream_fa: {use_upstream_fa}"
+            f"Using {self.attn_backend} for MultiHeadAttention in multimodal encoder."
         )
 
     def forward(
diff --git a/vllm/attention/ops/vit_attn_wrappers.py b/vllm/attention/ops/vit_attn_wrappers.py
index 46f8f5117f7a..d9f15f1e4285 100644
--- a/vllm/attention/ops/vit_attn_wrappers.py
+++ b/vllm/attention/ops/vit_attn_wrappers.py
@@ -27,15 +27,11 @@ def flash_attn_maxseqlen_wrapper(
     max_seqlen: torch.Tensor,
     batch_size: int,
     is_rocm_aiter: bool,
-    use_upstream_fa: bool,
 ) -> torch.Tensor:
     if is_rocm_aiter:
         from aiter import flash_attn_varlen_func
     else:
-        if use_upstream_fa:
-            from flash_attn import flash_attn_varlen_func
-        else:
-            from vllm.attention.utils.fa_utils import flash_attn_varlen_func
+        from vllm.attention.utils.fa_utils import flash_attn_varlen_func
     q, k, v = (einops.rearrange(x, "b s ... -> (b s) ...") for x in [q, k, v])
     output = flash_attn_varlen_func(
         q,
@@ -62,7 +58,6 @@ def flash_attn_maxseqlen_wrapper_fake(
     max_seqlen: torch.Tensor,
     batch_size: int,
     is_rocm_aiter: bool,
-    use_upstream_fa: bool,
 ) -> torch.Tensor:
     b, s, h, d = q.shape
     return torch.empty((s, b, h * d), dtype=q.dtype, device=q.device)
@@ -83,10 +78,9 @@ def vit_flash_attn_wrapper(
     max_seqlen: torch.Tensor,
     batch_size: int,
     is_rocm_aiter: bool,
-    use_upstream_fa: bool,
 ) -> torch.Tensor:
     return torch.ops.vllm.flash_attn_maxseqlen_wrapper(
-        q, k, v, cu_seqlens, max_seqlen, batch_size, is_rocm_aiter, use_upstream_fa
+        q, k, v, cu_seqlens, max_seqlen, batch_size, is_rocm_aiter
     )
diff --git a/vllm/attention/utils/fa_utils.py b/vllm/attention/utils/fa_utils.py
index adb9b08a6573..8a46587473e4 100644
--- a/vllm/attention/utils/fa_utils.py
+++ b/vllm/attention/utils/fa_utils.py
@@ -18,6 +18,14 @@
     reshape_and_cache_flash = ops.reshape_and_cache_flash
     flash_attn_varlen_func = ops.flash_attn_varlen_func
     get_scheduler_metadata = ops.get_scheduler_metadata
+elif current_platform.is_rocm():
+    try:
+        from flash_attn import flash_attn_varlen_func  # noqa: F401
+    except ImportError as e:
+        raise ImportError(
+            "Rocm platform requires upstream flash-attn "
+            "to be installed. Please install flash-attn first."
+        ) from e
 
 
 def get_flash_attn_version(requires_alibi: bool = False) -> int | None:
diff --git a/vllm/model_executor/models/dots_ocr.py b/vllm/model_executor/models/dots_ocr.py
index 5460018d0d67..5cc2a48f26d6 100644
--- a/vllm/model_executor/models/dots_ocr.py
+++ b/vllm/model_executor/models/dots_ocr.py
@@ -11,7 +11,6 @@
 from vllm.attention.backends.registry import AttentionBackendEnum
 from vllm.attention.layer import (
-    check_upstream_fa_availability,
     maybe_get_vit_flash_attn_backend,
 )
 from vllm.config import VllmConfig
@@ -294,12 +293,10 @@ def __init__(
             torch.get_default_dtype(),
             attn_backend_override=attn_backend_override,
         )
-        self.use_upstream_fa = False
 
         self.attn_backend, self.flash_attn_varlen_func = (
             maybe_get_vit_flash_attn_backend(
                 self.attn_backend,
-                self.use_upstream_fa,
                 attn_backend_override=attn_backend_override,
             )
         )
@@ -569,11 +566,6 @@ def __init__(
             dtype=torch.get_default_dtype(),
             attn_backend_override=attn_backend_override,
         )
-        if (
-            self.attn_backend != AttentionBackendEnum.FLASH_ATTN
-            and check_upstream_fa_availability(torch.get_default_dtype())
-        ):
-            self.attn_backend = AttentionBackendEnum.FLASH_ATTN
         self.out_hidden_size = config.hidden_size
         # Keep blocks for compatibility with other vision towers
         num_layers = (
diff --git a/vllm/model_executor/models/ernie45_vl.py b/vllm/model_executor/models/ernie45_vl.py
index 07b34fbc8add..81663dd7bbb4 100644
--- a/vllm/model_executor/models/ernie45_vl.py
+++ b/vllm/model_executor/models/ernie45_vl.py
@@ -38,7 +38,6 @@
 from vllm.attention.backends.registry import AttentionBackendEnum
 from vllm.attention.layer import (
-    check_upstream_fa_availability,
     maybe_get_vit_flash_attn_backend,
 )
 from vllm.config import VllmConfig
@@ -201,12 +200,9 @@ def __init__(
             attn_backend_override=attn_backend_override,
         )
 
-        self.use_upstream_fa = False
-
         self.attn_backend, self.flash_attn_varlen_func = (
             maybe_get_vit_flash_attn_backend(
                 self.attn_backend,
-                self.use_upstream_fa,
                 attn_backend_override=attn_backend_override,
             )
         )
@@ -498,11 +494,6 @@ def __init__(
             dtype=torch.get_default_dtype(),
             attn_backend_override=attn_backend_override,
         )
-        if (
-            self.attn_backend != AttentionBackendEnum.FLASH_ATTN
-            and check_upstream_fa_availability(torch.get_default_dtype())
-        ):
-            self.attn_backend = AttentionBackendEnum.FLASH_ATTN
 
     @property
     def dtype(self) -> torch.dtype:
diff --git a/vllm/model_executor/models/glm4_1v.py b/vllm/model_executor/models/glm4_1v.py
index 7e0370886884..fe238861ecce 100644
--- a/vllm/model_executor/models/glm4_1v.py
+++ b/vllm/model_executor/models/glm4_1v.py
@@ -47,10 +47,7 @@
 from transformers.video_utils import VideoMetadata
 
 from vllm.attention.backends.registry import AttentionBackendEnum
-from vllm.attention.layer import (
-    check_upstream_fa_availability,
-    maybe_get_vit_flash_attn_backend,
-)
+from vllm.attention.layer import maybe_get_vit_flash_attn_backend
 from vllm.config import VllmConfig
 from vllm.config.multimodal import BaseDummyOptions, VideoDummyOptions
 from vllm.distributed import get_tensor_model_parallel_world_size, parallel_state
@@ -296,12 +293,10 @@ def __init__(
             dtype=torch.get_default_dtype(),
             attn_backend_override=attn_backend_override,
         )
-        self.use_upstream_fa = False
 
         self.attn_backend, self.flash_attn_varlen_func = (
             maybe_get_vit_flash_attn_backend(
                 self.attn_backend,
-                self.use_upstream_fa,
                 attn_backend_override=attn_backend_override,
             )
         )
@@ -730,11 +725,6 @@ def __init__(
             dtype=torch.get_default_dtype(),
             attn_backend_override=attn_backend_override,
         )
-        if (
-            self.attn_backend != AttentionBackendEnum.FLASH_ATTN
-            and check_upstream_fa_availability(torch.get_default_dtype())
-        ):
-            self.attn_backend = AttentionBackendEnum.FLASH_ATTN
 
     @property
     def dtype(self) -> torch.dtype:
diff --git a/vllm/model_executor/models/keye.py b/vllm/model_executor/models/keye.py
index 302260b95299..881760155814 100644
--- a/vllm/model_executor/models/keye.py
+++ b/vllm/model_executor/models/keye.py
@@ -418,7 +418,6 @@ def __init__(
         self.attn_backend, self.flash_attn_varlen_func = (
             maybe_get_vit_flash_attn_backend(
                 self.attn_backend,
-                use_upstream_fa=False,
                 attn_backend_override=attn_backend_override,
             )
         )
diff --git a/vllm/model_executor/models/paddleocr_vl.py b/vllm/model_executor/models/paddleocr_vl.py
index 74bb868492da..5256d8ba7fd8 100644
--- a/vllm/model_executor/models/paddleocr_vl.py
+++ b/vllm/model_executor/models/paddleocr_vl.py
@@ -33,7 +33,6 @@
 from vllm.attention.backends.registry import AttentionBackendEnum
 from vllm.attention.layer import (
-    check_upstream_fa_availability,
     maybe_get_vit_flash_attn_backend,
 )
 from vllm.attention.ops.vit_attn_wrappers import (
@@ -582,7 +581,6 @@ def __init__(
         prefix: str = "",
         attn_backend: AttentionBackendEnum = AttentionBackendEnum.TORCH_SDPA,
         attn_backend_override: AttentionBackendEnum | None = None,
-        use_upstream_fa: bool = False,
     ) -> None:
         super().__init__()
 
@@ -612,11 +610,9 @@ def __init__(
         )
 
         self.attn_backend = attn_backend
-        self.use_upstream_fa = use_upstream_fa
         self.attn_backend, self.flash_attn_varlen_func = (
             maybe_get_vit_flash_attn_backend(
                 self.attn_backend,
-                self.use_upstream_fa,
                 attn_backend_override=attn_backend_override,
             )
         )
@@ -680,7 +676,6 @@ def forward(
                 max_seqlen,
                 batch_size,
                 self.attn_backend == AttentionBackendEnum.ROCM_AITER_FA,
-                self.use_upstream_fa,
             )
         elif self.attn_backend == AttentionBackendEnum.TORCH_SDPA:
             outputs = []
@@ -783,7 +778,6 @@ def __init__(
         *,
         attn_backend: AttentionBackendEnum = AttentionBackendEnum.TORCH_SDPA,
         attn_backend_override: AttentionBackendEnum | None = None,
-        use_upstream_fa: bool = False,
     ):
         super().__init__()
         self.embed_dim = config.hidden_size
@@ -796,7 +790,6 @@ def __init__(
             prefix=f"{prefix}.self_attn",
             attn_backend=attn_backend,
             attn_backend_override=attn_backend_override,
-            use_upstream_fa=use_upstream_fa,
         )
         self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
         self.mlp = SiglipMLP(
@@ -852,13 +845,6 @@ def __init__(
             dtype=torch.get_default_dtype(),
             attn_backend_override=attn_backend_override,
         )
-        self.use_upstream_fa = False
-        if self.attn_backend not in {
-            AttentionBackendEnum.FLASH_ATTN,
-            AttentionBackendEnum.ROCM_AITER_FA,
-        } and check_upstream_fa_availability(torch.get_default_dtype()):
-            self.attn_backend = AttentionBackendEnum.FLASH_ATTN
-            self.use_upstream_fa = True
         if self.attn_backend not in {
             AttentionBackendEnum.FLASH_ATTN,
             AttentionBackendEnum.TORCH_SDPA,
@@ -875,7 +861,6 @@ def __init__(
                     prefix=f"{prefix}.layers.{layer_idx}",
                     attn_backend=self.attn_backend,
                     attn_backend_override=attn_backend_override,
-                    use_upstream_fa=self.use_upstream_fa,
                 )
                 for layer_idx in range(config.num_hidden_layers)
             ]
diff --git a/vllm/model_executor/models/qwen2_5_vl.py b/vllm/model_executor/models/qwen2_5_vl.py
index 8c707c2561af..6ca490f46763 100644
--- a/vllm/model_executor/models/qwen2_5_vl.py
+++ b/vllm/model_executor/models/qwen2_5_vl.py
@@ -307,7 +307,6 @@ def __init__(
         prefix: str = "",
         use_data_parallel: bool = False,
         attn_backend: AttentionBackendEnum = AttentionBackendEnum.TORCH_SDPA,
-        use_upstream_fa: bool = False,
         attn_backend_override: AttentionBackendEnum | None = None,
     ) -> None:
         super().__init__()
@@ -344,24 +343,13 @@ def __init__(
             disable_tp=use_data_parallel,
         )
 
         self.attn_backend = attn_backend
-        self.use_upstream_fa = use_upstream_fa
         self.attn_backend, self.flash_attn_varlen_func = (
             maybe_get_vit_flash_attn_backend(
                 self.attn_backend,
-                self.use_upstream_fa,
                 attn_backend_override=attn_backend_override,
             )
         )
-        # On ROCm with FLASH_ATTN backend, upstream flash_attn is used
-        from vllm.platforms import current_platform
-        if (
-            current_platform.is_rocm()
-            and self.attn_backend == AttentionBackendEnum.FLASH_ATTN
-        ):
-            self.use_upstream_fa = True
-        if current_platform.is_xpu():
-            self.use_upstream_fa = False
         self.is_flash_attn_backend = self.attn_backend in {
             AttentionBackendEnum.FLASH_ATTN,
             AttentionBackendEnum.ROCM_AITER_FA,
@@ -415,7 +403,6 @@ def forward(
                 max_seqlen,
                 batch_size,
                 self.attn_backend == AttentionBackendEnum.ROCM_AITER_FA,
-                self.use_upstream_fa,
             )
         elif self.attn_backend == AttentionBackendEnum.TORCH_SDPA:
             # Execute attention entry by entry for speed & less VRAM.
@@ -459,7 +446,6 @@ def __init__(
         prefix: str = "",
         use_data_parallel: bool = False,
         attn_backend: AttentionBackendEnum = AttentionBackendEnum.TORCH_SDPA,
-        use_upstream_fa: bool = False,
         attn_backend_override: AttentionBackendEnum | None = None,
     ) -> None:
         super().__init__()
@@ -475,7 +461,6 @@ def __init__(
             prefix=f"{prefix}.attn",
             use_data_parallel=use_data_parallel,
             attn_backend=attn_backend,
-            use_upstream_fa=use_upstream_fa,
             attn_backend_override=attn_backend_override,
         )
         self.mlp = Qwen2_5_VisionMLP(
@@ -644,7 +629,6 @@ def __init__(
             is_neox_style=True,
         )
 
-        use_upstream_fa = False
         self.attn_backend = get_vit_attn_backend(
             head_size=head_dim,
             dtype=torch.get_default_dtype(),
@@ -654,7 +638,6 @@ def __init__(
         self.attn_backend, self.flash_attn_varlen_func = (
             maybe_get_vit_flash_attn_backend(
                 self.attn_backend,
-                use_upstream_fa,
                 attn_backend_override=attn_backend_override,
             )
         )
@@ -681,7 +664,6 @@ def __init__(
                     prefix=f"{prefix}.blocks.{layer_idx}",
                     use_data_parallel=use_data_parallel,
                     attn_backend=self.attn_backend,
-                    use_upstream_fa=use_upstream_fa,
                     attn_backend_override=attn_backend_override,
                 )
                 for layer_idx in range(depth)
diff --git a/vllm/model_executor/models/qwen2_vl.py b/vllm/model_executor/models/qwen2_vl.py
index 9d1d023aed17..672659aa6042 100644
--- a/vllm/model_executor/models/qwen2_vl.py
+++ b/vllm/model_executor/models/qwen2_vl.py
@@ -45,7 +45,6 @@
 from vllm.attention.backends.registry import AttentionBackendEnum
 from vllm.attention.layer import (
-    check_upstream_fa_availability,
     maybe_get_vit_flash_attn_backend,
 )
 from vllm.config import VllmConfig
@@ -335,12 +334,10 @@ def __init__(
             dtype=torch.get_default_dtype(),
             attn_backend_override=attn_backend_override,
         )
-        self.use_upstream_fa = False
 
         self.attn_backend, self.flash_attn_varlen_func = (
             maybe_get_vit_flash_attn_backend(
                 self.attn_backend,
-                self.use_upstream_fa,
                 attn_backend_override=attn_backend_override,
             )
         )
@@ -657,11 +654,6 @@ def __init__(
             dtype=torch.get_default_dtype(),
             attn_backend_override=attn_backend_override,
         )
-        if (
-            self.attn_backend != AttentionBackendEnum.FLASH_ATTN
-            and check_upstream_fa_availability(torch.get_default_dtype())
-        ):
-            self.attn_backend = AttentionBackendEnum.FLASH_ATTN
 
     @property
     def dtype(self) -> torch.dtype:
diff --git a/vllm/model_executor/models/qwen3_omni_moe_thinker.py b/vllm/model_executor/models/qwen3_omni_moe_thinker.py
index f5f88f66eff9..39dd42552ae8 100755
--- a/vllm/model_executor/models/qwen3_omni_moe_thinker.py
+++ b/vllm/model_executor/models/qwen3_omni_moe_thinker.py
@@ -47,7 +47,6 @@
 from transformers.models.whisper import WhisperFeatureExtractor
 
 from vllm.attention.backends.registry import AttentionBackendEnum
-from vllm.attention.layer import check_upstream_fa_availability
 from vllm.compilation.decorators import support_torch_compile
 from vllm.config import VllmConfig
 from vllm.distributed import get_pp_group
@@ -381,11 +380,6 @@ def __init__(
             dtype=torch.get_default_dtype(),
             attn_backend_override=attn_backend_override,
         )
-        if (
-            self.attn_backend != AttentionBackendEnum.FLASH_ATTN
-            and check_upstream_fa_availability(torch.get_default_dtype())
-        ):
-            self.attn_backend = AttentionBackendEnum.FLASH_ATTN
 
     @property
     def dtype(self) -> torch.dtype:
diff --git a/vllm/model_executor/models/qwen3_vl.py b/vllm/model_executor/models/qwen3_vl.py
index 4cd6fa14c32d..fa8698795245 100644
--- a/vllm/model_executor/models/qwen3_vl.py
+++ b/vllm/model_executor/models/qwen3_vl.py
@@ -49,7 +49,6 @@
 from transformers.video_utils import VideoMetadata
 
 from vllm.attention.backends.registry import AttentionBackendEnum
-from vllm.attention.layer import check_upstream_fa_availability
 from vllm.compilation.decorators import support_torch_compile
 from vllm.config import VllmConfig
 from vllm.config.multimodal import BaseDummyOptions, VideoDummyOptions
@@ -201,7 +200,6 @@ def __init__(
         prefix: str = "",
         use_data_parallel: bool = False,
         attn_backend: AttentionBackendEnum = AttentionBackendEnum.TORCH_SDPA,
-        use_upstream_fa: bool = False,
     ) -> None:
         super().__init__()
         if norm_layer is None:
@@ -216,7 +214,6 @@ def __init__(
             prefix=f"{prefix}.attn",
             use_data_parallel=use_data_parallel,
             attn_backend=attn_backend,
-            use_upstream_fa=use_upstream_fa,
         )
         self.mlp = Qwen3_VisionMLP(
             dim,
@@ -377,14 +374,6 @@ def __init__(
             dtype=torch.get_default_dtype(),
             attn_backend_override=attn_backend_override,
         )
-        use_upstream_fa = False
-        if (
-            self.attn_backend != AttentionBackendEnum.FLASH_ATTN
-            and self.attn_backend != AttentionBackendEnum.ROCM_AITER_FA
-            and check_upstream_fa_availability(torch.get_default_dtype())
-        ):
-            self.attn_backend = AttentionBackendEnum.FLASH_ATTN
-            use_upstream_fa = True
 
         if self.attn_backend not in {
             AttentionBackendEnum.FLASH_ATTN,
@@ -406,7 +395,6 @@ def __init__(
                     prefix=f"{prefix}.blocks.{layer_idx}",
                     use_data_parallel=use_data_parallel,
                     attn_backend=self.attn_backend,
-                    use_upstream_fa=use_upstream_fa,
                 )
                 for layer_idx in range(vision_config.depth)
             ]
diff --git a/vllm/model_executor/models/siglip2navit.py b/vllm/model_executor/models/siglip2navit.py
index c185b45345bd..bbce01995412 100644
--- a/vllm/model_executor/models/siglip2navit.py
+++ b/vllm/model_executor/models/siglip2navit.py
@@ -255,12 +255,10 @@ def __init__(
             dtype=torch.get_default_dtype(),
             attn_backend_override=attn_backend_override,
         )
-        self.use_upstream_fa = False
 
         self.attn_backend, self.flash_attn_varlen_func = (
             maybe_get_vit_flash_attn_backend(
                 self.attn_backend,
-                self.use_upstream_fa,
                 attn_backend_override=attn_backend_override,
             )
         )
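Not part of the patch: the sketch below only illustrates the calling convention that the touched model files share after this change, namely resolving the ViT attention backend and its varlen kernel without any `use_upstream_fa` plumbing. It assumes a vLLM tree containing the modules imported here; the wrapper function itself is hypothetical and exists only for illustration.

```python
# Illustration only: mirrors the post-patch call pattern from the model files above.
# The imported helpers come from the patch; resolve_vit_attn is a hypothetical wrapper.
from collections.abc import Callable

from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.attention.layer import maybe_get_vit_flash_attn_backend


def resolve_vit_attn(
    attn_backend: AttentionBackendEnum,
    attn_backend_override: AttentionBackendEnum | None = None,
) -> tuple[AttentionBackendEnum, Callable | None, bool]:
    # No use_upstream_fa argument anymore: the helper returns the varlen kernel
    # for FLASH_ATTN / ROCM_AITER_FA backends, or None when falling back to SDPA.
    attn_backend, flash_attn_varlen_func = maybe_get_vit_flash_attn_backend(
        attn_backend,
        attn_backend_override=attn_backend_override,
    )
    is_flash_attn_backend = attn_backend in {
        AttentionBackendEnum.FLASH_ATTN,
        AttentionBackendEnum.ROCM_AITER_FA,
    }
    return attn_backend, flash_attn_varlen_func, is_flash_attn_backend
```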