@@ -20,22 +20,22 @@ def test_env(name: str, device: str, monkeypatch):
 
     if device == "cpu":
         with patch("vllm.attention.selector.is_cpu", return_value=True):
-            backend = which_attn_to_use(8, 16, 8, None, torch.float16,
-                                        torch.float16, 16)
+            backend = which_attn_to_use(16, None, torch.float16, torch.float16,
+                                        16, False)
         assert backend.name == "TORCH_SDPA"
     elif device == "hip":
         with patch("vllm.attention.selector.is_hip", return_value=True):
-            backend = which_attn_to_use(8, 16, 8, None, torch.float16,
-                                        torch.float16, 16)
+            backend = which_attn_to_use(16, None, torch.float16, torch.float16,
+                                        16, False)
         assert backend.name == "ROCM_FLASH"
     elif device == "openvino":
         with patch("vllm.attention.selector.is_openvino", return_value=True):
-            backend = which_attn_to_use(8, 16, 8, None, torch.float16,
-                                        torch.float16, 16)
+            backend = which_attn_to_use(16, None, torch.float16, torch.float16,
+                                        16, False)
         assert backend.name == "OPENVINO"
     else:
-        backend = which_attn_to_use(8, 16, 8, None, torch.float16,
-                                    torch.float16, 16)
+        backend = which_attn_to_use(16, None, torch.float16, torch.float16, 16,
+                                    False)
         assert backend.name == name
 
 
@@ -46,37 +46,42 @@ def test_flash_attn(monkeypatch):
 
     # Unsupported CUDA arch
     with patch("torch.cuda.get_device_capability", return_value=(7, 5)):
-        backend = which_attn_to_use(8, 16, 8, None, torch.float16, None, 16)
+        backend = which_attn_to_use(16, None, torch.float16, None, 16, False)
     assert backend.name != STR_FLASH_ATTN_VAL
 
     # Unsupported data type
-    backend = which_attn_to_use(8, 16, 8, None, torch.float8_e4m3fn, None, 16)
+    backend = which_attn_to_use(16, None, torch.float8_e4m3fn, None, 16, False)
     assert backend.name != STR_FLASH_ATTN_VAL
 
     # Unsupported kv cache data type
-    backend = which_attn_to_use(8, 16, 8, None, torch.float16, "fp8", 16)
+    backend = which_attn_to_use(16, None, torch.float16, "fp8", 16, False)
     assert backend.name != STR_FLASH_ATTN_VAL
 
     # Unsupported block size
-    backend = which_attn_to_use(8, 16, 8, None, torch.float16, None, 8)
+    backend = which_attn_to_use(16, None, torch.float16, None, 8, False)
     assert backend.name != STR_FLASH_ATTN_VAL
 
     # Unsupported sliding window
-    backend = which_attn_to_use(8, 16, 8, 1, torch.float16, None, 16)
+    backend = which_attn_to_use(16, 1, torch.float16, None, 16, False)
     assert backend.name != STR_FLASH_ATTN_VAL
 
     # flash-attn is not installed
     with patch.dict('sys.modules', {'vllm_flash_attn': None}):
-        backend = which_attn_to_use(8, 16, 8, None, torch.float16, None, 16)
+        backend = which_attn_to_use(16, None, torch.float16, None, 16, False)
         assert backend.name != STR_FLASH_ATTN_VAL
 
     # Unsupported head size
-    backend = which_attn_to_use(8, 17, 8, None, torch.float16, None, 16)
+    backend = which_attn_to_use(17, None, torch.float16, None, 16, False)
+    assert backend.name != STR_FLASH_ATTN_VAL
+
+    # Attention-free models should bypass env and use PlaceholderAttention
+    backend = which_attn_to_use(16, None, torch.float16, torch.float16, 16,
+                                True)
     assert backend.name != STR_FLASH_ATTN_VAL
 
 
 def test_invalid_env(monkeypatch):
     """Throw an exception if the backend name is invalid."""
     override_backend_env_variable(monkeypatch, STR_INVALID_VAL)
     with pytest.raises(ValueError):
-        which_attn_to_use(8, 16, 8, None, torch.float16, None, 16)
+        which_attn_to_use(16, None, torch.float16, None, 16, False)
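For context, a minimal sketch of the call pattern these tests now exercise. The argument names in the comment are assumptions inferred from positions in the diff (the old seven-argument form dropped its first and third arguments and gained a trailing boolean for attention-free models); only `which_attn_to_use`, the `vllm.attention.selector` module, and `backend.name` are taken from the patch itself.

```python
# Minimal sketch, not part of the patch: the updated positional call used by
# these tests. Parameter names below are assumed from argument positions:
# (head_size, sliding_window, dtype, kv_cache_dtype, block_size,
#  is_attention_free).
import torch

from vllm.attention.selector import which_attn_to_use

backend = which_attn_to_use(16, None, torch.float16, None, 16, False)
print(backend.name)  # backend selected for a regular (not attention-free) model
```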