51 changes: 1 addition & 50 deletions vllm/attention/layer.py
@@ -53,53 +53,20 @@
logger = init_logger(__name__)


def check_upstream_fa_availability(dtype: torch.dtype):
if (
dtype in (torch.float16, torch.bfloat16)
and current_platform.is_cuda()
and current_platform.has_device_capability(80)
):
from transformers.utils import is_flash_attn_2_available

return is_flash_attn_2_available()
if current_platform.is_rocm():
from importlib.util import find_spec

return find_spec("flash_attn") is not None
return False


def maybe_get_vit_flash_attn_backend(
attn_backend: AttentionBackendEnum,
use_upstream_fa: bool,
attn_backend_override: AttentionBackendEnum | None = None,
) -> tuple[AttentionBackendEnum, Callable | None]:
if current_platform.is_rocm():
if envs.VLLM_ROCM_USE_AITER and envs.VLLM_ROCM_USE_AITER_MHA and on_gfx9():
attn_backend = AttentionBackendEnum.ROCM_AITER_FA

elif (
check_upstream_fa_availability(torch.get_default_dtype())
and on_gfx9()
and attn_backend_override is None
):
attn_backend = AttentionBackendEnum.FLASH_ATTN
use_upstream_fa = True
else:
return AttentionBackendEnum.TORCH_SDPA, None

elif current_platform.is_cuda():
if (
attn_backend != AttentionBackendEnum.FLASH_ATTN
and check_upstream_fa_availability(torch.get_default_dtype())
):
attn_backend = AttentionBackendEnum.FLASH_ATTN
use_upstream_fa = True
elif current_platform.is_xpu():
assert attn_backend == AttentionBackendEnum.FLASH_ATTN, (
"XPU platform only supports FLASH_ATTN as vision attention backend."
)
use_upstream_fa = False
else:
return AttentionBackendEnum.TORCH_SDPA, None
Comment on lines 81 to 82

P1: Restore CUDA path in ViT flash-attn selection

maybe_get_vit_flash_attn_backend now returns TORCH_SDPA for any platform that is neither ROCm nor XPU. CUDA is caught by this else branch, so even when get_vit_attn_backend selects FLASH_ATTN, the function forces Torch SDPA and never loads the flash-attention kernel. This effectively disables flash attention for all vision models on CUDA, degrading the intended fast path everywhere and ignoring user overrides.
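One possible shape of the restored selection logic, as a sketch only — the ROCm/XPU branches are paraphrased from the hunk above, and helpers such as current_platform, envs, on_gfx9, and AttentionBackendEnum are assumed to keep their existing imports in vllm/attention/layer.py:

```python
def maybe_get_vit_flash_attn_backend(
    attn_backend: AttentionBackendEnum,
    attn_backend_override: AttentionBackendEnum | None = None,
) -> tuple[AttentionBackendEnum, Callable | None]:
    if current_platform.is_rocm():
        if envs.VLLM_ROCM_USE_AITER and envs.VLLM_ROCM_USE_AITER_MHA and on_gfx9():
            attn_backend = AttentionBackendEnum.ROCM_AITER_FA
        elif on_gfx9() and attn_backend_override is None:
            attn_backend = AttentionBackendEnum.FLASH_ATTN
        else:
            return AttentionBackendEnum.TORCH_SDPA, None
    elif current_platform.is_cuda():
        # CUDA serves FLASH_ATTN through vLLM's fa_utils, so keep the caller's
        # (or the override's) selection instead of falling through to SDPA.
        pass
    elif current_platform.is_xpu():
        assert attn_backend == AttentionBackendEnum.FLASH_ATTN, (
            "XPU platform only supports FLASH_ATTN as vision attention backend."
        )
    else:
        return AttentionBackendEnum.TORCH_SDPA, None

    if attn_backend == AttentionBackendEnum.ROCM_AITER_FA:
        from aiter import flash_attn_varlen_func
    elif attn_backend == AttentionBackendEnum.FLASH_ATTN:
        from vllm.attention.utils.fa_utils import flash_attn_varlen_func
    else:
        flash_attn_varlen_func = None
    return attn_backend, flash_attn_varlen_func
```

The key addition is the `elif current_platform.is_cuda(): pass` branch, which keeps CUDA out of the generic `else` that returns TORCH_SDPA.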


@@ -110,10 +77,7 @@ def maybe_get_vit_flash_attn_backend(
if attn_backend == AttentionBackendEnum.ROCM_AITER_FA:
from aiter import flash_attn_varlen_func
else:
if use_upstream_fa:
from flash_attn import flash_attn_varlen_func
else:
from vllm.attention.utils.fa_utils import flash_attn_varlen_func
from vllm.attention.utils.fa_utils import flash_attn_varlen_func
@tjtanaa (Collaborator), Nov 26, 2025:

vllm/attention/utils/fa_utils.py does not have the ROCm logic, so flash_attn_varlen_func will be None if imported this way.

We can keep the `from flash_attn import flash_attn_varlen_func` import for now. Otherwise we have to add that import to vllm/attention/utils/fa_utils.py when the platform is ROCm.

Contributor (author) reply:

I added this import to fa_utils, as that looked like the simplest way to do it, and the except message tells the user to install upstream flash-attn when an ImportError is raised. Please check whether that works, thanks!
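For reference, a minimal sketch of what that fallback in vllm/attention/utils/fa_utils.py could look like (illustrative only; the PR's actual change may differ, and the existing CUDA/XPU resolution is assumed to stay as-is):

```python
from vllm.platforms import current_platform

if current_platform.is_rocm():
    try:
        # ROCm has no vLLM-native flash-attn build, so fall back to upstream.
        from flash_attn import flash_attn_varlen_func
    except ImportError as e:
        raise ImportError(
            "The FLASH_ATTN vision backend on ROCm requires upstream "
            "flash-attn; please install the flash-attn package."
        ) from e
# The existing CUDA/XPU branches that resolve flash_attn_varlen_func are
# left unchanged and omitted here.
```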

else:
flash_attn_varlen_func = None

@@ -498,10 +462,6 @@ def __init__(
attn_backend_override=attn_backend_override,
)

# Some auto-selected backends can be upgraded
# to upstream flash attention if available.
# If vllm native fa is selected, we use it directly.
use_upstream_fa = False

self.attn_backend = (
backend
@@ -518,7 +478,6 @@ def __init__(
self.attn_backend, self._flash_attn_varlen_func = (
maybe_get_vit_flash_attn_backend(
self.attn_backend,
use_upstream_fa,
attn_backend_override=attn_backend_override,
)
)
@@ -528,17 +487,9 @@ def __init__(
AttentionBackendEnum.ROCM_AITER_FA,
}

# this condition is just to make sure that the
# use_upstream_fa in the log is correct
if (
current_platform.is_rocm()
and self.attn_backend == AttentionBackendEnum.FLASH_ATTN
):
use_upstream_fa = True

logger.info_once(
f"MultiHeadAttention attn_backend: {self.attn_backend}, "
f"use_upstream_fa: {use_upstream_fa}"
)

def forward(
10 changes: 2 additions & 8 deletions vllm/attention/ops/vit_attn_wrappers.py
@@ -27,15 +27,11 @@ def flash_attn_maxseqlen_wrapper(
max_seqlen: torch.Tensor,
batch_size: int,
is_rocm_aiter: bool,
use_upstream_fa: bool,
) -> torch.Tensor:
if is_rocm_aiter:
from aiter import flash_attn_varlen_func
else:
if use_upstream_fa:
from flash_attn import flash_attn_varlen_func
else:
from vllm.attention.utils.fa_utils import flash_attn_varlen_func
from vllm.attention.utils.fa_utils import flash_attn_varlen_func
@tjtanaa (Collaborator), Nov 26, 2025:

Likewise, vllm/attention/utils/fa_utils.py does not have the ROCm logic, so flash_attn_varlen_func will be None if imported this way.

We can keep the `from flash_attn import flash_attn_varlen_func` import for now. Otherwise we have to add that import to vllm/attention/utils/fa_utils.py when the platform is ROCm.
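Until that is settled, a defensive guard at the top of the wrapper would fail with an actionable message instead of a TypeError on a None callable (hypothetical sketch, not part of this PR):

```python
from vllm.attention.utils.fa_utils import flash_attn_varlen_func

# fa_utils may resolve to None on platforms it does not cover (e.g. ROCm today).
if flash_attn_varlen_func is None:
    raise RuntimeError(
        "flash_attn_varlen_func is unavailable on this platform; install "
        "upstream flash-attn or select the TORCH_SDPA vision attention backend."
    )
```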

q, k, v = (einops.rearrange(x, "b s ... -> (b s) ...") for x in [q, k, v])
output = flash_attn_varlen_func(
q,
@@ -62,7 +58,6 @@ def flash_attn_maxseqlen_wrapper_fake(
max_seqlen: torch.Tensor,
batch_size: int,
is_rocm_aiter: bool,
use_upstream_fa: bool,
) -> torch.Tensor:
b, s, h, d = q.shape
return torch.empty((s, b, h * d), dtype=q.dtype, device=q.device)
@@ -83,10 +78,9 @@ def vit_flash_attn_wrapper(
max_seqlen: torch.Tensor,
batch_size: int,
is_rocm_aiter: bool,
use_upstream_fa: bool,
) -> torch.Tensor:
return torch.ops.vllm.flash_attn_maxseqlen_wrapper(
q, k, v, cu_seqlens, max_seqlen, batch_size, is_rocm_aiter, use_upstream_fa
q, k, v, cu_seqlens, max_seqlen, batch_size, is_rocm_aiter
)


2 changes: 0 additions & 2 deletions vllm/model_executor/models/dots_ocr.py
@@ -294,12 +294,10 @@ def __init__(
torch.get_default_dtype(),
attn_backend_override=attn_backend_override,
)
self.use_upstream_fa = False

self.attn_backend, self.flash_attn_varlen_func = (
maybe_get_vit_flash_attn_backend(
self.attn_backend,
self.use_upstream_fa,
attn_backend_override=attn_backend_override,
)
)
2 changes: 0 additions & 2 deletions vllm/model_executor/models/ernie45_vl.py
@@ -201,12 +201,10 @@ def __init__(
attn_backend_override=attn_backend_override,
)

self.use_upstream_fa = False

self.attn_backend, self.flash_attn_varlen_func = (
maybe_get_vit_flash_attn_backend(
self.attn_backend,
self.use_upstream_fa,
attn_backend_override=attn_backend_override,
)
)
2 changes: 0 additions & 2 deletions vllm/model_executor/models/glm4_1v.py
@@ -296,12 +296,10 @@ def __init__(
dtype=torch.get_default_dtype(),
attn_backend_override=attn_backend_override,
)
self.use_upstream_fa = False

self.attn_backend, self.flash_attn_varlen_func = (
maybe_get_vit_flash_attn_backend(
self.attn_backend,
self.use_upstream_fa,
attn_backend_override=attn_backend_override,
)
)
1 change: 0 additions & 1 deletion vllm/model_executor/models/keye.py
@@ -418,7 +418,6 @@ def __init__(
self.attn_backend, self.flash_attn_varlen_func = (
maybe_get_vit_flash_attn_backend(
self.attn_backend,
use_upstream_fa=False,
attn_backend_override=attn_backend_override,
)
)
9 changes: 0 additions & 9 deletions vllm/model_executor/models/paddleocr_vl.py
@@ -582,7 +582,6 @@ def __init__(
prefix: str = "",
attn_backend: AttentionBackendEnum = AttentionBackendEnum.TORCH_SDPA,
attn_backend_override: AttentionBackendEnum | None = None,
use_upstream_fa: bool = False,
) -> None:
super().__init__()

@@ -612,11 +611,9 @@ def __init__(
)

self.attn_backend = attn_backend
self.use_upstream_fa = use_upstream_fa
self.attn_backend, self.flash_attn_varlen_func = (
maybe_get_vit_flash_attn_backend(
self.attn_backend,
self.use_upstream_fa,
attn_backend_override=attn_backend_override,
)
)
@@ -680,7 +677,6 @@ def forward(
max_seqlen,
batch_size,
self.attn_backend == AttentionBackendEnum.ROCM_AITER_FA,
self.use_upstream_fa,
)
elif self.attn_backend == AttentionBackendEnum.TORCH_SDPA:
outputs = []
@@ -783,7 +779,6 @@ def __init__(
*,
attn_backend: AttentionBackendEnum = AttentionBackendEnum.TORCH_SDPA,
attn_backend_override: AttentionBackendEnum | None = None,
use_upstream_fa: bool = False,
):
super().__init__()
self.embed_dim = config.hidden_size
@@ -796,7 +791,6 @@ def __init__(
prefix=f"{prefix}.self_attn",
attn_backend=attn_backend,
attn_backend_override=attn_backend_override,
use_upstream_fa=use_upstream_fa,
)
self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
self.mlp = SiglipMLP(
@@ -852,13 +846,11 @@ def __init__(
dtype=torch.get_default_dtype(),
attn_backend_override=attn_backend_override,
)
self.use_upstream_fa = False
if self.attn_backend not in {
AttentionBackendEnum.FLASH_ATTN,
AttentionBackendEnum.ROCM_AITER_FA,
} and check_upstream_fa_availability(torch.get_default_dtype()):
self.attn_backend = AttentionBackendEnum.FLASH_ATTN
self.use_upstream_fa = True
if self.attn_backend not in {
AttentionBackendEnum.FLASH_ATTN,
AttentionBackendEnum.TORCH_SDPA,
@@ -875,7 +867,6 @@ def __init__(
prefix=f"{prefix}.layers.{layer_idx}",
attn_backend=self.attn_backend,
attn_backend_override=attn_backend_override,
use_upstream_fa=self.use_upstream_fa,
)
for layer_idx in range(config.num_hidden_layers)
]
16 changes: 0 additions & 16 deletions vllm/model_executor/models/qwen2_5_vl.py
@@ -307,7 +307,6 @@ def __init__(
prefix: str = "",
use_data_parallel: bool = False,
attn_backend: AttentionBackendEnum = AttentionBackendEnum.TORCH_SDPA,
use_upstream_fa: bool = False,
attn_backend_override: AttentionBackendEnum | None = None,
) -> None:
super().__init__()
@@ -344,24 +343,15 @@ def __init__(
disable_tp=use_data_parallel,
)
self.attn_backend = attn_backend
self.use_upstream_fa = use_upstream_fa
self.attn_backend, self.flash_attn_varlen_func = (
maybe_get_vit_flash_attn_backend(
self.attn_backend,
self.use_upstream_fa,
attn_backend_override=attn_backend_override,
)
)
# On ROCm with FLASH_ATTN backend, upstream flash_attn is used
from vllm.platforms import current_platform

if (
current_platform.is_rocm()
and self.attn_backend == AttentionBackendEnum.FLASH_ATTN
):
self.use_upstream_fa = True
if current_platform.is_xpu():
self.use_upstream_fa = False
self.is_flash_attn_backend = self.attn_backend in {
AttentionBackendEnum.FLASH_ATTN,
AttentionBackendEnum.ROCM_AITER_FA,
@@ -415,7 +405,6 @@ def forward(
max_seqlen,
batch_size,
self.attn_backend == AttentionBackendEnum.ROCM_AITER_FA,
self.use_upstream_fa,
)
elif self.attn_backend == AttentionBackendEnum.TORCH_SDPA:
# Execute attention entry by entry for speed & less VRAM.
@@ -459,7 +448,6 @@ def __init__(
prefix: str = "",
use_data_parallel: bool = False,
attn_backend: AttentionBackendEnum = AttentionBackendEnum.TORCH_SDPA,
use_upstream_fa: bool = False,
attn_backend_override: AttentionBackendEnum | None = None,
) -> None:
super().__init__()
@@ -475,7 +463,6 @@ def __init__(
prefix=f"{prefix}.attn",
use_data_parallel=use_data_parallel,
attn_backend=attn_backend,
use_upstream_fa=use_upstream_fa,
attn_backend_override=attn_backend_override,
)
self.mlp = Qwen2_5_VisionMLP(
@@ -644,7 +631,6 @@ def __init__(
is_neox_style=True,
)

use_upstream_fa = False
self.attn_backend = get_vit_attn_backend(
head_size=head_dim,
dtype=torch.get_default_dtype(),
@@ -654,7 +640,6 @@ def __init__(
self.attn_backend, self.flash_attn_varlen_func = (
maybe_get_vit_flash_attn_backend(
self.attn_backend,
use_upstream_fa,
attn_backend_override=attn_backend_override,
)
)
@@ -681,7 +666,6 @@ def __init__(
prefix=f"{prefix}.blocks.{layer_idx}",
use_data_parallel=use_data_parallel,
attn_backend=self.attn_backend,
use_upstream_fa=use_upstream_fa,
attn_backend_override=attn_backend_override,
)
for layer_idx in range(depth)
2 changes: 0 additions & 2 deletions vllm/model_executor/models/qwen2_vl.py
@@ -335,12 +335,10 @@ def __init__(
dtype=torch.get_default_dtype(),
attn_backend_override=attn_backend_override,
)
self.use_upstream_fa = False

self.attn_backend, self.flash_attn_varlen_func = (
maybe_get_vit_flash_attn_backend(
self.attn_backend,
self.use_upstream_fa,
attn_backend_override=attn_backend_override,
)
)
11 changes: 0 additions & 11 deletions vllm/model_executor/models/qwen3_vl.py
@@ -201,7 +201,6 @@ def __init__(
prefix: str = "",
use_data_parallel: bool = False,
attn_backend: AttentionBackendEnum = AttentionBackendEnum.TORCH_SDPA,
use_upstream_fa: bool = False,
) -> None:
super().__init__()
if norm_layer is None:
@@ -216,7 +215,6 @@ def __init__(
prefix=f"{prefix}.attn",
use_data_parallel=use_data_parallel,
attn_backend=attn_backend,
use_upstream_fa=use_upstream_fa,
)
self.mlp = Qwen3_VisionMLP(
dim,
@@ -377,14 +375,6 @@ def __init__(
dtype=torch.get_default_dtype(),
attn_backend_override=attn_backend_override,
)
use_upstream_fa = False
if (
self.attn_backend != AttentionBackendEnum.FLASH_ATTN
and self.attn_backend != AttentionBackendEnum.ROCM_AITER_FA
and check_upstream_fa_availability(torch.get_default_dtype())
):
self.attn_backend = AttentionBackendEnum.FLASH_ATTN
use_upstream_fa = True

if self.attn_backend not in {
AttentionBackendEnum.FLASH_ATTN,
@@ -406,7 +396,6 @@ def __init__(
prefix=f"{prefix}.blocks.{layer_idx}",
use_data_parallel=use_data_parallel,
attn_backend=self.attn_backend,
use_upstream_fa=use_upstream_fa,
)
for layer_idx in range(vision_config.depth)
]
2 changes: 0 additions & 2 deletions vllm/model_executor/models/siglip2navit.py
@@ -255,12 +255,10 @@ def __init__(
dtype=torch.get_default_dtype(),
attn_backend_override=attn_backend_override,
)
self.use_upstream_fa = False

self.attn_backend, self.flash_attn_varlen_func = (
maybe_get_vit_flash_attn_backend(
self.attn_backend,
self.use_upstream_fa,
attn_backend_override=attn_backend_override,
)
)