57 changes: 7 additions & 50 deletions vllm/attention/layer.py
@@ -56,53 +56,28 @@
logger = init_logger(__name__)


-def check_upstream_fa_availability(dtype: torch.dtype):
-    if (
-        dtype in (torch.float16, torch.bfloat16)
-        and current_platform.is_cuda()
-        and current_platform.has_device_capability(80)
-    ):
-        from transformers.utils import is_flash_attn_2_available
-
-        return is_flash_attn_2_available()
-    if current_platform.is_rocm():
-        from importlib.util import find_spec
-
-        return find_spec("flash_attn") is not None
-    return False
-
-
def maybe_get_vit_flash_attn_backend(
attn_backend: AttentionBackendEnum,
-    use_upstream_fa: bool,
attn_backend_override: AttentionBackendEnum | None = None,
) -> tuple[AttentionBackendEnum, Callable | None]:
if current_platform.is_rocm():
if envs.VLLM_ROCM_USE_AITER and envs.VLLM_ROCM_USE_AITER_MHA and on_gfx9():
attn_backend = AttentionBackendEnum.ROCM_AITER_FA

        elif (
-            check_upstream_fa_availability(torch.get_default_dtype())
+            attn_backend_override is None
            and on_gfx9()
-            and attn_backend_override is None
+            and attn_backend == AttentionBackendEnum.FLASH_ATTN
        ):
-            attn_backend = AttentionBackendEnum.FLASH_ATTN
-            use_upstream_fa = True
+            pass
        else:
            return AttentionBackendEnum.TORCH_SDPA, None

elif current_platform.is_cuda():
-        if (
-            attn_backend != AttentionBackendEnum.FLASH_ATTN
-            and check_upstream_fa_availability(torch.get_default_dtype())
-        ):
-            attn_backend = AttentionBackendEnum.FLASH_ATTN
-            use_upstream_fa = True
+        pass
elif current_platform.is_xpu():
assert attn_backend == AttentionBackendEnum.FLASH_ATTN, (
"XPU platform only supports FLASH_ATTN as vision attention backend."
)
-        use_upstream_fa = False
+        pass
else:
return AttentionBackendEnum.TORCH_SDPA, None
Comment on lines 81 to 82

P1: Restore CUDA path in vit flash-attn selection

maybe_get_vit_flash_attn_backend now returns TORCH_SDPA for any platform that is neither ROCm nor XPU. CUDA is caught by this else branch, so even when get_vit_attn_backend selects FLASH_ATTN, the function forces Torch SDPA and never loads the flash attention kernel. This effectively disables flash attention for all vision models on CUDA, degrading the intended fast path everywhere and ignoring user overrides.

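For context, a minimal sketch of the selection behaviour this comment asks for; `Backend` and `pick_vit_backend` are hypothetical stand-ins, not vLLM APIs. The point is that on CUDA the backend chosen by get_vit_attn_backend should be kept rather than falling through to the generic TORCH_SDPA fallback.

```python
# Illustrative only: Backend and pick_vit_backend are stand-ins, not vLLM code.
from enum import Enum, auto


class Backend(Enum):
    FLASH_ATTN = auto()
    TORCH_SDPA = auto()


def pick_vit_backend(requested: Backend, platform: str) -> Backend:
    # CUDA (like ROCm and XPU) keeps the backend chosen upstream; only truly
    # unsupported platforms should fall back to SDPA.
    if platform in ("cuda", "rocm", "xpu"):
        return requested
    return Backend.TORCH_SDPA


# The behaviour the comment wants restored:
assert pick_vit_backend(Backend.FLASH_ATTN, "cuda") is Backend.FLASH_ATTN
```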


@@ -113,10 +88,7 @@ def maybe_get_vit_flash_attn_backend(
if attn_backend == AttentionBackendEnum.ROCM_AITER_FA:
from aiter import flash_attn_varlen_func
else:
-            if use_upstream_fa:
-                from flash_attn import flash_attn_varlen_func
-            else:
-                from vllm.attention.utils.fa_utils import flash_attn_varlen_func
+            from vllm.attention.utils.fa_utils import flash_attn_varlen_func
@tjtanaa (Collaborator) commented on Nov 26, 2025:

vllm/attention/utils/fa_utils.py does not have the ROCm logic, so flash_attn_varlen_func will be a None object if imported this way.

We can keep the `from flash_attn import flash_attn_varlen_func` statement for now. Otherwise we have to add that `from flash_attn import flash_attn_varlen_func` import to vllm/attention/utils/fa_utils.py when the platform is ROCm.

The PR author (Contributor) replied:

I added this import to fa_utils, as it looks like the simplest way to do it, and the except message tells the user to install upstream flash-attn when an ImportError is raised. Please check if that works, thanks!
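For reference, a minimal sketch of the guard described in this reply, assuming the usual vllm.platforms.current_platform helper; it mirrors the fa_utils.py hunk further down in this diff, and the error-message wording here is illustrative.

```python
# Sketch of a ROCm-only fallback to the upstream flash-attn kernel.
from vllm.platforms import current_platform

if current_platform.is_rocm():
    try:
        # On ROCm the varlen kernel comes from upstream flash-attn.
        from flash_attn import flash_attn_varlen_func  # noqa: F401
    except ImportError as e:
        raise ImportError(
            "ROCm needs upstream flash-attn for the ViT attention path; "
            "please install flash-attn first."
        ) from e
```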

else:
flash_attn_varlen_func = None

@@ -501,11 +473,6 @@ def __init__(
attn_backend_override=attn_backend_override,
)

-        # Some auto-selected backends can be upgraded
-        # to upstream flash attention if available.
-        # If vllm native fa is selected, we use it directly.
-        use_upstream_fa = False
-
self.attn_backend = (
backend
if backend
@@ -521,7 +488,6 @@
self.attn_backend, self._flash_attn_varlen_func = (
maybe_get_vit_flash_attn_backend(
self.attn_backend,
-                use_upstream_fa,
attn_backend_override=attn_backend_override,
)
)
@@ -531,17 +497,8 @@
AttentionBackendEnum.ROCM_AITER_FA,
}

-        # this condition is just to make sure that the
-        # use_upstream_fa in the log is correct
-        if (
-            current_platform.is_rocm()
-            and self.attn_backend == AttentionBackendEnum.FLASH_ATTN
-        ):
-            use_upstream_fa = True
-
        logger.info_once(
-            f"MultiHeadAttention attn_backend: {self.attn_backend}, "
-            f"use_upstream_fa: {use_upstream_fa}"
+            f"Using {self.attn_backend} for MultiHeadAttention in multimodal encoder."
        )
)

def forward(
10 changes: 2 additions & 8 deletions vllm/attention/ops/vit_attn_wrappers.py
@@ -27,15 +27,11 @@ def flash_attn_maxseqlen_wrapper(
max_seqlen: torch.Tensor,
batch_size: int,
is_rocm_aiter: bool,
-    use_upstream_fa: bool,
) -> torch.Tensor:
if is_rocm_aiter:
from aiter import flash_attn_varlen_func
else:
-        if use_upstream_fa:
-            from flash_attn import flash_attn_varlen_func
-        else:
-            from vllm.attention.utils.fa_utils import flash_attn_varlen_func
+        from vllm.attention.utils.fa_utils import flash_attn_varlen_func
@tjtanaa (Collaborator) commented on Nov 26, 2025:

Likewise, vllm/attention/utils/fa_utils.py does not have the ROCm logic, so flash_attn_varlen_func will be a None object if imported this way.

We can keep the `from flash_attn import flash_attn_varlen_func` statement for now. Otherwise we have to add that import to vllm/attention/utils/fa_utils.py when the platform is ROCm.

q, k, v = (einops.rearrange(x, "b s ... -> (b s) ...") for x in [q, k, v])
output = flash_attn_varlen_func(
q,
@@ -62,7 +58,6 @@ def flash_attn_maxseqlen_wrapper_fake(
max_seqlen: torch.Tensor,
batch_size: int,
is_rocm_aiter: bool,
-    use_upstream_fa: bool,
) -> torch.Tensor:
b, s, h, d = q.shape
return torch.empty((s, b, h * d), dtype=q.dtype, device=q.device)
@@ -83,10 +78,9 @@ def vit_flash_attn_wrapper(
max_seqlen: torch.Tensor,
batch_size: int,
is_rocm_aiter: bool,
-    use_upstream_fa: bool,
) -> torch.Tensor:
return torch.ops.vllm.flash_attn_maxseqlen_wrapper(
q, k, v, cu_seqlens, max_seqlen, batch_size, is_rocm_aiter, use_upstream_fa
q, k, v, cu_seqlens, max_seqlen, batch_size, is_rocm_aiter
)


8 changes: 8 additions & 0 deletions vllm/attention/utils/fa_utils.py
@@ -18,6 +18,14 @@
reshape_and_cache_flash = ops.reshape_and_cache_flash
flash_attn_varlen_func = ops.flash_attn_varlen_func
get_scheduler_metadata = ops.get_scheduler_metadata
+elif current_platform.is_rocm():
+    try:
+        from flash_attn import flash_attn_varlen_func  # noqa: F401
+    except ImportError as e:
+        raise ImportError(
+            "ROCm platform requires upstream flash-attn "
+            "to be installed. Please install flash-attn first."
+        ) from e


def get_flash_attn_version(requires_alibi: bool = False) -> int | None:
8 changes: 0 additions & 8 deletions vllm/model_executor/models/dots_ocr.py
@@ -11,7 +11,6 @@

from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.attention.layer import (
-    check_upstream_fa_availability,
maybe_get_vit_flash_attn_backend,
)
from vllm.config import VllmConfig
@@ -294,12 +293,10 @@ def __init__(
torch.get_default_dtype(),
attn_backend_override=attn_backend_override,
)
-        self.use_upstream_fa = False

self.attn_backend, self.flash_attn_varlen_func = (
maybe_get_vit_flash_attn_backend(
self.attn_backend,
-                self.use_upstream_fa,
attn_backend_override=attn_backend_override,
)
)
@@ -569,11 +566,6 @@ def __init__(
dtype=torch.get_default_dtype(),
attn_backend_override=attn_backend_override,
)
-        if (
-            self.attn_backend != AttentionBackendEnum.FLASH_ATTN
-            and check_upstream_fa_availability(torch.get_default_dtype())
-        ):
-            self.attn_backend = AttentionBackendEnum.FLASH_ATTN
self.out_hidden_size = config.hidden_size
# Keep blocks for compatibility with other vision towers
num_layers = (
9 changes: 0 additions & 9 deletions vllm/model_executor/models/ernie45_vl.py
@@ -38,7 +38,6 @@

from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.attention.layer import (
-    check_upstream_fa_availability,
maybe_get_vit_flash_attn_backend,
)
from vllm.config import VllmConfig
@@ -201,12 +200,9 @@ def __init__(
attn_backend_override=attn_backend_override,
)

-        self.use_upstream_fa = False
-
self.attn_backend, self.flash_attn_varlen_func = (
maybe_get_vit_flash_attn_backend(
self.attn_backend,
-                self.use_upstream_fa,
attn_backend_override=attn_backend_override,
)
)
@@ -498,11 +494,6 @@ def __init__(
dtype=torch.get_default_dtype(),
attn_backend_override=attn_backend_override,
)
-        if (
-            self.attn_backend != AttentionBackendEnum.FLASH_ATTN
-            and check_upstream_fa_availability(torch.get_default_dtype())
-        ):
-            self.attn_backend = AttentionBackendEnum.FLASH_ATTN

@property
def dtype(self) -> torch.dtype:
12 changes: 1 addition & 11 deletions vllm/model_executor/models/glm4_1v.py
@@ -47,10 +47,7 @@
from transformers.video_utils import VideoMetadata

from vllm.attention.backends.registry import AttentionBackendEnum
-from vllm.attention.layer import (
-    check_upstream_fa_availability,
-    maybe_get_vit_flash_attn_backend,
-)
+from vllm.attention.layer import maybe_get_vit_flash_attn_backend
from vllm.config import VllmConfig
from vllm.config.multimodal import BaseDummyOptions, VideoDummyOptions
from vllm.distributed import get_tensor_model_parallel_world_size, parallel_state
@@ -296,12 +293,10 @@ def __init__(
dtype=torch.get_default_dtype(),
attn_backend_override=attn_backend_override,
)
-        self.use_upstream_fa = False

self.attn_backend, self.flash_attn_varlen_func = (
maybe_get_vit_flash_attn_backend(
self.attn_backend,
-                self.use_upstream_fa,
attn_backend_override=attn_backend_override,
)
)
@@ -730,11 +725,6 @@ def __init__(
dtype=torch.get_default_dtype(),
attn_backend_override=attn_backend_override,
)
-        if (
-            self.attn_backend != AttentionBackendEnum.FLASH_ATTN
-            and check_upstream_fa_availability(torch.get_default_dtype())
-        ):
-            self.attn_backend = AttentionBackendEnum.FLASH_ATTN

@property
def dtype(self) -> torch.dtype:
1 change: 0 additions & 1 deletion vllm/model_executor/models/keye.py
@@ -418,7 +418,6 @@ def __init__(
self.attn_backend, self.flash_attn_varlen_func = (
maybe_get_vit_flash_attn_backend(
self.attn_backend,
-                use_upstream_fa=False,
attn_backend_override=attn_backend_override,
)
)
15 changes: 0 additions & 15 deletions vllm/model_executor/models/paddleocr_vl.py
@@ -33,7 +33,6 @@

from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.attention.layer import (
-    check_upstream_fa_availability,
maybe_get_vit_flash_attn_backend,
)
from vllm.attention.ops.vit_attn_wrappers import (
@@ -582,7 +581,6 @@ def __init__(
prefix: str = "",
attn_backend: AttentionBackendEnum = AttentionBackendEnum.TORCH_SDPA,
attn_backend_override: AttentionBackendEnum | None = None,
-        use_upstream_fa: bool = False,
) -> None:
super().__init__()

@@ -612,11 +610,9 @@
)

self.attn_backend = attn_backend
-        self.use_upstream_fa = use_upstream_fa
self.attn_backend, self.flash_attn_varlen_func = (
maybe_get_vit_flash_attn_backend(
self.attn_backend,
-                self.use_upstream_fa,
attn_backend_override=attn_backend_override,
)
)
@@ -680,7 +676,6 @@ def forward(
max_seqlen,
batch_size,
self.attn_backend == AttentionBackendEnum.ROCM_AITER_FA,
-                self.use_upstream_fa,
)
elif self.attn_backend == AttentionBackendEnum.TORCH_SDPA:
outputs = []
@@ -783,7 +778,6 @@ def __init__(
*,
attn_backend: AttentionBackendEnum = AttentionBackendEnum.TORCH_SDPA,
attn_backend_override: AttentionBackendEnum | None = None,
-        use_upstream_fa: bool = False,
):
super().__init__()
self.embed_dim = config.hidden_size
@@ -796,7 +790,6 @@ def __init__(
prefix=f"{prefix}.self_attn",
attn_backend=attn_backend,
attn_backend_override=attn_backend_override,
-            use_upstream_fa=use_upstream_fa,
)
self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
self.mlp = SiglipMLP(
@@ -852,13 +845,6 @@ def __init__(
dtype=torch.get_default_dtype(),
attn_backend_override=attn_backend_override,
)
-        self.use_upstream_fa = False
-        if self.attn_backend not in {
-            AttentionBackendEnum.FLASH_ATTN,
-            AttentionBackendEnum.ROCM_AITER_FA,
-        } and check_upstream_fa_availability(torch.get_default_dtype()):
-            self.attn_backend = AttentionBackendEnum.FLASH_ATTN
-            self.use_upstream_fa = True
if self.attn_backend not in {
AttentionBackendEnum.FLASH_ATTN,
AttentionBackendEnum.TORCH_SDPA,
@@ -875,7 +861,6 @@
prefix=f"{prefix}.layers.{layer_idx}",
attn_backend=self.attn_backend,
attn_backend_override=attn_backend_override,
-                use_upstream_fa=self.use_upstream_fa,
)
for layer_idx in range(config.num_hidden_layers)
]