
Commit 460d8bb

Victor49152 and ywang96 authored
Remove upstream fa checks (#29471)
Signed-off-by: mingyuanm <mingyuanm@nvidia.com>
Signed-off-by: Roger Wang <hey@rogerw.io>
Co-authored-by: Roger Wang <hey@rogerw.io>
1 parent e2f56c3 commit 460d8bb

13 files changed: +18 additions, -148 deletions


vllm/attention/layer.py

Lines changed: 7 additions & 50 deletions
@@ -56,53 +56,28 @@
 logger = init_logger(__name__)


-def check_upstream_fa_availability(dtype: torch.dtype):
-    if (
-        dtype in (torch.float16, torch.bfloat16)
-        and current_platform.is_cuda()
-        and current_platform.has_device_capability(80)
-    ):
-        from transformers.utils import is_flash_attn_2_available
-
-        return is_flash_attn_2_available()
-    if current_platform.is_rocm():
-        from importlib.util import find_spec
-
-        return find_spec("flash_attn") is not None
-    return False
-
-
 def maybe_get_vit_flash_attn_backend(
     attn_backend: AttentionBackendEnum,
-    use_upstream_fa: bool,
     attn_backend_override: AttentionBackendEnum | None = None,
 ) -> tuple[AttentionBackendEnum, Callable | None]:
     if current_platform.is_rocm():
         if envs.VLLM_ROCM_USE_AITER and envs.VLLM_ROCM_USE_AITER_MHA and on_gfx9():
             attn_backend = AttentionBackendEnum.ROCM_AITER_FA
-
         elif (
-            check_upstream_fa_availability(torch.get_default_dtype())
+            attn_backend_override is None
             and on_gfx9()
-            and attn_backend_override is None
+            and attn_backend == AttentionBackendEnum.FLASH_ATTN
         ):
-            attn_backend = AttentionBackendEnum.FLASH_ATTN
-            use_upstream_fa = True
+            pass
         else:
             return AttentionBackendEnum.TORCH_SDPA, None
-
     elif current_platform.is_cuda():
-        if (
-            attn_backend != AttentionBackendEnum.FLASH_ATTN
-            and check_upstream_fa_availability(torch.get_default_dtype())
-        ):
-            attn_backend = AttentionBackendEnum.FLASH_ATTN
-            use_upstream_fa = True
+        pass
     elif current_platform.is_xpu():
         assert attn_backend == AttentionBackendEnum.FLASH_ATTN, (
            "XPU platform only supports FLASH_ATTN as vision attention backend."
         )
-        use_upstream_fa = False
+        pass
     else:
         return AttentionBackendEnum.TORCH_SDPA, None

@@ -113,10 +88,7 @@ def maybe_get_vit_flash_attn_backend(
         if attn_backend == AttentionBackendEnum.ROCM_AITER_FA:
             from aiter import flash_attn_varlen_func
         else:
-            if use_upstream_fa:
-                from flash_attn import flash_attn_varlen_func
-            else:
-                from vllm.attention.utils.fa_utils import flash_attn_varlen_func
+            from vllm.attention.utils.fa_utils import flash_attn_varlen_func
     else:
         flash_attn_varlen_func = None

@@ -501,11 +473,6 @@ def __init__(
             attn_backend_override=attn_backend_override,
         )

-        # Some auto-selected backends can be upgraded
-        # to upstream flash attention if available.
-        # If vllm native fa is selected, we use it directly.
-        use_upstream_fa = False
-
         self.attn_backend = (
             backend
             if backend
@@ -521,7 +488,6 @@ def __init__(
         self.attn_backend, self._flash_attn_varlen_func = (
             maybe_get_vit_flash_attn_backend(
                 self.attn_backend,
-                use_upstream_fa,
                 attn_backend_override=attn_backend_override,
             )
         )
@@ -531,17 +497,8 @@ def __init__(
             AttentionBackendEnum.ROCM_AITER_FA,
         }

-        # this condition is just to make sure that the
-        # use_upstream_fa in the log is correct
-        if (
-            current_platform.is_rocm()
-            and self.attn_backend == AttentionBackendEnum.FLASH_ATTN
-        ):
-            use_upstream_fa = True
-
         logger.info_once(
-            f"MultiHeadAttention attn_backend: {self.attn_backend}, "
-            f"use_upstream_fa: {use_upstream_fa}"
+            f"Using {self.attn_backend} for MultiHeadAttention in multimodal encoder."
         )

     def forward(
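With the use_upstream_fa flag removed, callers of maybe_get_vit_flash_attn_backend only pass the candidate backend plus an optional override, and get back the (possibly downgraded) backend together with the varlen attention callable. A minimal sketch of the new calling convention, assuming an installed vLLM with this change; the chosen backend value is illustrative:

    from vllm.attention.backends.registry import AttentionBackendEnum
    from vllm.attention.layer import maybe_get_vit_flash_attn_backend

    # Candidate backend selected earlier by the model (illustrative value).
    attn_backend = AttentionBackendEnum.FLASH_ATTN

    attn_backend, flash_attn_varlen_func = maybe_get_vit_flash_attn_backend(
        attn_backend,
        attn_backend_override=None,
    )

    # On unsupported platforms the helper falls back to TORCH_SDPA and returns
    # None for the callable, so guard before using it.
    if flash_attn_varlen_func is None:
        print(f"Vision attention falls back to {attn_backend}")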

vllm/attention/ops/vit_attn_wrappers.py

Lines changed: 2 additions & 8 deletions
@@ -27,15 +27,11 @@ def flash_attn_maxseqlen_wrapper(
     max_seqlen: torch.Tensor,
     batch_size: int,
     is_rocm_aiter: bool,
-    use_upstream_fa: bool,
 ) -> torch.Tensor:
     if is_rocm_aiter:
         from aiter import flash_attn_varlen_func
     else:
-        if use_upstream_fa:
-            from flash_attn import flash_attn_varlen_func
-        else:
-            from vllm.attention.utils.fa_utils import flash_attn_varlen_func
+        from vllm.attention.utils.fa_utils import flash_attn_varlen_func
     q, k, v = (einops.rearrange(x, "b s ... -> (b s) ...") for x in [q, k, v])
     output = flash_attn_varlen_func(
         q,
@@ -62,7 +58,6 @@ def flash_attn_maxseqlen_wrapper_fake(
     max_seqlen: torch.Tensor,
     batch_size: int,
     is_rocm_aiter: bool,
-    use_upstream_fa: bool,
 ) -> torch.Tensor:
     b, s, h, d = q.shape
     return torch.empty((s, b, h * d), dtype=q.dtype, device=q.device)
@@ -83,10 +78,9 @@ def vit_flash_attn_wrapper(
     max_seqlen: torch.Tensor,
     batch_size: int,
     is_rocm_aiter: bool,
-    use_upstream_fa: bool,
 ) -> torch.Tensor:
     return torch.ops.vllm.flash_attn_maxseqlen_wrapper(
-        q, k, v, cu_seqlens, max_seqlen, batch_size, is_rocm_aiter, use_upstream_fa
+        q, k, v, cu_seqlens, max_seqlen, batch_size, is_rocm_aiter
     )
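The custom-op wrapper now takes one argument fewer. A shape-level sketch of a call with the updated signature; the tensor sizes are made up, and actually running it needs a CUDA or ROCm build with a flash-attn backend available:

    import torch

    from vllm.attention.ops.vit_attn_wrappers import vit_flash_attn_wrapper

    b, s, h, d = 2, 16, 8, 64
    q = torch.randn(b, s, h, d, dtype=torch.bfloat16, device="cuda")
    k = torch.randn_like(q)
    v = torch.randn_like(q)
    # One sequence per batch element: boundaries at 0, s, 2 * s.
    cu_seqlens = torch.arange(0, (b + 1) * s, s, dtype=torch.int32, device="cuda")
    max_seqlen = torch.tensor(s, device="cuda")

    out = vit_flash_attn_wrapper(
        q, k, v, cu_seqlens, max_seqlen, batch_size=b, is_rocm_aiter=False
    )
    print(out.shape)  # (s, b, h * d), matching the fake-op registration above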

vllm/attention/utils/fa_utils.py

Lines changed: 8 additions & 0 deletions
@@ -18,6 +18,14 @@
     reshape_and_cache_flash = ops.reshape_and_cache_flash
     flash_attn_varlen_func = ops.flash_attn_varlen_func
     get_scheduler_metadata = ops.get_scheduler_metadata
+elif current_platform.is_rocm():
+    try:
+        from flash_attn import flash_attn_varlen_func  # noqa: F401
+    except ImportError as e:
+        raise ImportError(
+            "Rocm platform requires upstream flash-attn "
+            "to be installed. Please install flash-attn first."
+        ) from e


 def get_flash_attn_version(requires_alibi: bool = False) -> int | None:
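The new ROCm branch is a fail-fast import guard: it resolves flash_attn_varlen_func at module import time and turns a missing flash-attn install into an actionable error instead of a late failure inside the attention path. The same pattern in isolation, where some_optional_dep is a made-up placeholder and not a real vLLM dependency:

    try:
        from some_optional_dep import fast_varlen_kernel  # noqa: F401
    except ImportError as e:
        raise ImportError(
            "This platform requires some_optional_dep; please install it first."
        ) from e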

vllm/model_executor/models/dots_ocr.py

Lines changed: 0 additions & 8 deletions
@@ -11,7 +11,6 @@

 from vllm.attention.backends.registry import AttentionBackendEnum
 from vllm.attention.layer import (
-    check_upstream_fa_availability,
     maybe_get_vit_flash_attn_backend,
 )
 from vllm.config import VllmConfig
@@ -294,12 +293,10 @@ def __init__(
             torch.get_default_dtype(),
             attn_backend_override=attn_backend_override,
         )
-        self.use_upstream_fa = False

         self.attn_backend, self.flash_attn_varlen_func = (
             maybe_get_vit_flash_attn_backend(
                 self.attn_backend,
-                self.use_upstream_fa,
                 attn_backend_override=attn_backend_override,
             )
         )
@@ -569,11 +566,6 @@ def __init__(
             dtype=torch.get_default_dtype(),
             attn_backend_override=attn_backend_override,
         )
-        if (
-            self.attn_backend != AttentionBackendEnum.FLASH_ATTN
-            and check_upstream_fa_availability(torch.get_default_dtype())
-        ):
-            self.attn_backend = AttentionBackendEnum.FLASH_ATTN
         self.out_hidden_size = config.hidden_size
         # Keep blocks for compatibility with other vision towers
         num_layers = (
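The model files in this commit all follow the same simplification: drop the self.use_upstream_fa bookkeeping and the post-hoc backend upgrade, and let maybe_get_vit_flash_attn_backend decide. A minimal sketch of the resulting init pattern; DummyVisionAttention is an illustrative stand-in, not one of the classes touched here:

    import torch.nn as nn

    from vllm.attention.layer import maybe_get_vit_flash_attn_backend


    class DummyVisionAttention(nn.Module):
        def __init__(self, attn_backend, attn_backend_override=None):
            super().__init__()
            # The helper both validates the backend for this platform and
            # returns the varlen attention callable (or None for SDPA).
            self.attn_backend, self.flash_attn_varlen_func = (
                maybe_get_vit_flash_attn_backend(
                    attn_backend,
                    attn_backend_override=attn_backend_override,
                )
            )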

vllm/model_executor/models/ernie45_vl.py

Lines changed: 0 additions & 9 deletions
@@ -38,7 +38,6 @@

 from vllm.attention.backends.registry import AttentionBackendEnum
 from vllm.attention.layer import (
-    check_upstream_fa_availability,
     maybe_get_vit_flash_attn_backend,
 )
 from vllm.config import VllmConfig
@@ -201,12 +200,9 @@ def __init__(
             attn_backend_override=attn_backend_override,
         )

-        self.use_upstream_fa = False
-
         self.attn_backend, self.flash_attn_varlen_func = (
             maybe_get_vit_flash_attn_backend(
                 self.attn_backend,
-                self.use_upstream_fa,
                 attn_backend_override=attn_backend_override,
             )
         )
@@ -498,11 +494,6 @@ def __init__(
             dtype=torch.get_default_dtype(),
             attn_backend_override=attn_backend_override,
         )
-        if (
-            self.attn_backend != AttentionBackendEnum.FLASH_ATTN
-            and check_upstream_fa_availability(torch.get_default_dtype())
-        ):
-            self.attn_backend = AttentionBackendEnum.FLASH_ATTN

     @property
     def dtype(self) -> torch.dtype:

vllm/model_executor/models/glm4_1v.py

Lines changed: 1 addition & 11 deletions
@@ -47,10 +47,7 @@
 from transformers.video_utils import VideoMetadata

 from vllm.attention.backends.registry import AttentionBackendEnum
-from vllm.attention.layer import (
-    check_upstream_fa_availability,
-    maybe_get_vit_flash_attn_backend,
-)
+from vllm.attention.layer import maybe_get_vit_flash_attn_backend
 from vllm.config import VllmConfig
 from vllm.config.multimodal import BaseDummyOptions, VideoDummyOptions
 from vllm.distributed import get_tensor_model_parallel_world_size, parallel_state
@@ -296,12 +293,10 @@ def __init__(
             dtype=torch.get_default_dtype(),
             attn_backend_override=attn_backend_override,
         )
-        self.use_upstream_fa = False

         self.attn_backend, self.flash_attn_varlen_func = (
             maybe_get_vit_flash_attn_backend(
                 self.attn_backend,
-                self.use_upstream_fa,
                 attn_backend_override=attn_backend_override,
             )
         )
@@ -730,11 +725,6 @@ def __init__(
             dtype=torch.get_default_dtype(),
             attn_backend_override=attn_backend_override,
         )
-        if (
-            self.attn_backend != AttentionBackendEnum.FLASH_ATTN
-            and check_upstream_fa_availability(torch.get_default_dtype())
-        ):
-            self.attn_backend = AttentionBackendEnum.FLASH_ATTN

     @property
     def dtype(self) -> torch.dtype:

vllm/model_executor/models/keye.py

Lines changed: 0 additions & 1 deletion
@@ -418,7 +418,6 @@ def __init__(
         self.attn_backend, self.flash_attn_varlen_func = (
             maybe_get_vit_flash_attn_backend(
                 self.attn_backend,
-                use_upstream_fa=False,
                 attn_backend_override=attn_backend_override,
             )
         )

vllm/model_executor/models/paddleocr_vl.py

Lines changed: 0 additions & 15 deletions
@@ -33,7 +33,6 @@

 from vllm.attention.backends.registry import AttentionBackendEnum
 from vllm.attention.layer import (
-    check_upstream_fa_availability,
     maybe_get_vit_flash_attn_backend,
 )
 from vllm.attention.ops.vit_attn_wrappers import (
@@ -582,7 +581,6 @@ def __init__(
         prefix: str = "",
         attn_backend: AttentionBackendEnum = AttentionBackendEnum.TORCH_SDPA,
         attn_backend_override: AttentionBackendEnum | None = None,
-        use_upstream_fa: bool = False,
     ) -> None:
         super().__init__()

@@ -612,11 +610,9 @@ def __init__(
         )

         self.attn_backend = attn_backend
-        self.use_upstream_fa = use_upstream_fa
         self.attn_backend, self.flash_attn_varlen_func = (
             maybe_get_vit_flash_attn_backend(
                 self.attn_backend,
-                self.use_upstream_fa,
                 attn_backend_override=attn_backend_override,
             )
         )
@@ -680,7 +676,6 @@ def forward(
                 max_seqlen,
                 batch_size,
                 self.attn_backend == AttentionBackendEnum.ROCM_AITER_FA,
-                self.use_upstream_fa,
             )
         elif self.attn_backend == AttentionBackendEnum.TORCH_SDPA:
             outputs = []
@@ -783,7 +778,6 @@ def __init__(
         *,
         attn_backend: AttentionBackendEnum = AttentionBackendEnum.TORCH_SDPA,
         attn_backend_override: AttentionBackendEnum | None = None,
-        use_upstream_fa: bool = False,
     ):
         super().__init__()
         self.embed_dim = config.hidden_size
@@ -796,7 +790,6 @@ def __init__(
             prefix=f"{prefix}.self_attn",
             attn_backend=attn_backend,
             attn_backend_override=attn_backend_override,
-            use_upstream_fa=use_upstream_fa,
         )
         self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
         self.mlp = SiglipMLP(
@@ -852,13 +845,6 @@ def __init__(
             dtype=torch.get_default_dtype(),
             attn_backend_override=attn_backend_override,
         )
-        self.use_upstream_fa = False
-        if self.attn_backend not in {
-            AttentionBackendEnum.FLASH_ATTN,
-            AttentionBackendEnum.ROCM_AITER_FA,
-        } and check_upstream_fa_availability(torch.get_default_dtype()):
-            self.attn_backend = AttentionBackendEnum.FLASH_ATTN
-            self.use_upstream_fa = True
         if self.attn_backend not in {
             AttentionBackendEnum.FLASH_ATTN,
             AttentionBackendEnum.TORCH_SDPA,
@@ -875,7 +861,6 @@ def __init__(
                     prefix=f"{prefix}.layers.{layer_idx}",
                     attn_backend=self.attn_backend,
                     attn_backend_override=attn_backend_override,
-                    use_upstream_fa=self.use_upstream_fa,
                 )
                 for layer_idx in range(config.num_hidden_layers)
             ]
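paddleocr_vl.py is the one model that also threaded use_upstream_fa through constructor signatures (attention, encoder layer, encoder); after this change only the backend enum and the override are passed down. A trimmed sketch of that constructor chain, using simplified stand-in classes rather than the actual SigLIP modules:

    import torch.nn as nn

    from vllm.attention.backends.registry import AttentionBackendEnum


    class Attention(nn.Module):
        def __init__(self, attn_backend, attn_backend_override=None):
            super().__init__()
            self.attn_backend = attn_backend  # no use_upstream_fa field


    class EncoderLayer(nn.Module):
        def __init__(self, attn_backend, attn_backend_override=None):
            super().__init__()
            self.self_attn = Attention(attn_backend, attn_backend_override)


    class Encoder(nn.Module):
        def __init__(self, num_layers, attn_backend, attn_backend_override=None):
            super().__init__()
            self.layers = nn.ModuleList(
                EncoderLayer(attn_backend, attn_backend_override)
                for _ in range(num_layers)
            )


    encoder = Encoder(2, AttentionBackendEnum.TORCH_SDPA)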
