 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig)
 from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
+from vllm.model_executor.models.vision import get_vit_attn_backend
 from vllm.platforms import _Backend, current_platform
 from vllm.utils import direct_register_custom_op
 
@@ -55,6 +56,14 @@ def check_xformers_availability():
     return USE_XFORMERS_OPS
 
 
+def check_upstream_fa_availability(dtype: torch.dtype):
+    if dtype in (torch.float16, torch.bfloat16) and current_platform.is_cuda(
+    ) and current_platform.has_device_capability(80):
+        from transformers.utils import is_flash_attn_2_available
+        return is_flash_attn_2_available()
+    return False
+
+
 class Attention(nn.Module, AttentionLayerBase):
     """Attention layer.
 
@@ -349,29 +358,55 @@ def __init__(
             f"divisible by num_kv_heads ({self.num_kv_heads})"
         self.num_queries_per_kv = self.num_heads // self.num_kv_heads
 
+        # During model initialization, the default dtype is set as the model
+        # weight and activation dtype.
         dtype = torch.get_default_dtype()
-        attn_backend = get_attn_backend(head_size,
-                                        dtype,
-                                        kv_cache_dtype=None,
-                                        block_size=16,
-                                        is_attention_free=False)
-        backend = backend_name_to_enum(attn_backend.get_name())
+
+        # Determine the attention backend
+        backend = get_vit_attn_backend(head_size=head_size, dtype=dtype)
+
+        # Some auto-selected backends can be upgraded
+        # to upstream flash attention if available.
+        # If vllm native fa is selected, we use it directly.
+        use_upstream_fa = False
+        if backend != _Backend.FLASH_ATTN and check_upstream_fa_availability(
+                dtype):
+            backend = _Backend.FLASH_ATTN
+            use_upstream_fa = True
+
         if current_platform.is_rocm():
            # currently, only torch_sdpa is supported on rocm
            self.attn_backend = _Backend.TORCH_SDPA
        else:
+
            self.attn_backend = backend if backend in {
                _Backend.TORCH_SDPA,
                _Backend.TORCH_SDPA_VLLM_V1,
                _Backend.XFORMERS,
                _Backend.PALLAS_VLLM_V1,
                _Backend.ROCM_AITER_FA,
-            } else current_platform.get_vit_attn_backend()
+                _Backend.FLASH_ATTN,
+                _Backend.FLASH_ATTN_VLLM_V1,
+            } else _Backend.TORCH_SDPA
 
         if (self.attn_backend == _Backend.XFORMERS
                 and not check_xformers_availability()):
             self.attn_backend = _Backend.TORCH_SDPA
 
+        if self.attn_backend in {
+                _Backend.FLASH_ATTN, _Backend.FLASH_ATTN_VLLM_V1
+        }:
+            if use_upstream_fa:
+                from flash_attn import flash_attn_varlen_func
+                self._flash_attn_varlen_func = flash_attn_varlen_func
+            else:
+                from vllm.vllm_flash_attn import flash_attn_varlen_func
+                self._flash_attn_varlen_func = flash_attn_varlen_func
+
+        logger.info_once(
+            f"MultiHeadAttention attn_backend: {self.attn_backend}, "
+            f"use_upstream_fa: {use_upstream_fa}")
+
     def forward(
         self,
         query: torch.Tensor,
@@ -392,7 +427,31 @@ def forward(
             key = torch.repeat_interleave(key, num_repeat, dim=2)
             value = torch.repeat_interleave(value, num_repeat, dim=2)
 
-        if self.attn_backend == _Backend.XFORMERS:
+        if self.attn_backend in {
+                _Backend.FLASH_ATTN,
+                _Backend.FLASH_ATTN_VLLM_V1,
+        }:
+
+            cu_seqlens_q = torch.arange(0, (bsz + 1) * q_len,
+                                        step=q_len,
+                                        dtype=torch.int32,
+                                        device=query.device)
+            cu_seqlens_k = torch.arange(0, (bsz + 1) * kv_len,
+                                        step=kv_len,
+                                        dtype=torch.int32,
+                                        device=key.device)
+
+            out = self._flash_attn_varlen_func(
+                query.flatten(0, 1),
+                key.flatten(0, 1),
+                value.flatten(0, 1),
+                cu_seqlens_q=cu_seqlens_q,
+                cu_seqlens_k=cu_seqlens_k,
+                max_seqlen_q=q_len,
+                max_seqlen_k=kv_len,
+                softmax_scale=self.scale,
+            )
+        elif self.attn_backend == _Backend.XFORMERS:
             from xformers import ops as xops
 
             out = xops.memory_efficient_attention_forward(query,
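For context on the packed ("varlen") layout used in the new flash-attention branch, below is a small standalone sketch, not part of this diff, that only depends on torch. The names bsz, q_len, kv_len, num_heads, and head_dim mirror the shapes in forward(); the concrete values are illustrative. It shows how the cu_seqlens_q / cu_seqlens_k tensors built with torch.arange mark per-sequence boundaries once the (bsz, seq_len, num_heads, head_dim) tensors are flattened with flatten(0, 1) into the packed layout that flash_attn_varlen_func consumes.

import torch

# Illustrative shapes only; any bsz / length / head combination behaves the same.
bsz, q_len, kv_len, num_heads, head_dim = 3, 5, 7, 4, 32

query = torch.randn(bsz, q_len, num_heads, head_dim)
key = torch.randn(bsz, kv_len, num_heads, head_dim)

# Same construction as in the diff: cumulative sequence lengths, one entry per
# sequence boundary, e.g. [0, 5, 10, 15] for bsz=3, q_len=5.
cu_seqlens_q = torch.arange(0, (bsz + 1) * q_len, step=q_len, dtype=torch.int32)
cu_seqlens_k = torch.arange(0, (bsz + 1) * kv_len, step=kv_len, dtype=torch.int32)

# flatten(0, 1) packs the batch into (total_tokens, num_heads, head_dim);
# sequence i occupies rows cu_seqlens[i]:cu_seqlens[i + 1] of the packed tensor.
packed_q = query.flatten(0, 1)
for i in range(bsz):
    start, end = cu_seqlens_q[i].item(), cu_seqlens_q[i + 1].item()
    assert torch.equal(packed_q[start:end], query[i])

print(cu_seqlens_q.tolist())  # [0, 5, 10, 15]
print(cu_seqlens_k.tolist())  # [0, 7, 14, 21]

Because every sequence on this path has the same length, the cumulative lengths form an arithmetic progression with step q_len (or kv_len), which is why a plain torch.arange is sufficient in the diff and max_seqlen_q / max_seqlen_k are simply q_len and kv_len.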