
Commit 6b6742e

rasmith (Randall Smith) authored and committed
[CI/Build][AMD] Fix import errors in tests/kernels/attention (vllm-project#29032)
Signed-off-by: Randall Smith <ransmith@amd.com>
Co-authored-by: Randall Smith <ransmith@amd.com>
Signed-off-by: LuminolT <lumischen01@gmail.com>
1 parent 0682b9a commit 6b6742e

File tree: 6 files changed (+49, -15 lines)

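Most of the files in this change apply the same guard: the optional dependency is imported inside a try/except at module scope, and if the import fails on ROCm the whole test module is skipped at collection time instead of erroring out (two files use an else-branch variant, shown further down). A minimal standalone sketch of that pattern follows; the imported module name and the skip message are placeholders, not taken from any one file.

import pytest

from vllm.platforms import current_platform

try:
    import some_optional_kernel_lib  # placeholder for the optional dependency
except ImportError:
    if current_platform.is_rocm():
        # Skip every test in this module during collection rather than
        # failing with an ImportError on ROCm builds that lack the package.
        pytest.skip(
            "some_optional_kernel_lib is not supported for vLLM on ROCm.",
            allow_module_level=True,
        )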

tests/kernels/attention/test_cascade_flash_attn.py

Lines changed: 13 additions & 5 deletions
@@ -7,11 +7,19 @@

 from vllm.platforms import current_platform
 from vllm.v1.attention.backends.flash_attn import cascade_attention, merge_attn_states
-from vllm.vllm_flash_attn import (
-    fa_version_unsupported_reason,
-    flash_attn_varlen_func,
-    is_fa_version_supported,
-)
+
+try:
+    from vllm.vllm_flash_attn import (
+        fa_version_unsupported_reason,
+        flash_attn_varlen_func,
+        is_fa_version_supported,
+    )
+except ImportError:
+    if current_platform.is_rocm():
+        pytest.skip(
+            "vllm_flash_attn is not supported for vLLM on ROCm.",
+            allow_module_level=True,
+        )

 NUM_HEADS = [(4, 4), (8, 2), (16, 2)]
 HEAD_SIZES = [128, 192, 256]

tests/kernels/attention/test_flash_attn.py

Lines changed: 14 additions & 5 deletions
@@ -6,11 +6,20 @@
 import torch

 from vllm.platforms import current_platform
-from vllm.vllm_flash_attn import (
-    fa_version_unsupported_reason,
-    flash_attn_varlen_func,
-    is_fa_version_supported,
-)
+
+try:
+    from vllm.vllm_flash_attn import (
+        fa_version_unsupported_reason,
+        flash_attn_varlen_func,
+        is_fa_version_supported,
+    )
+except ImportError:
+    if current_platform.is_rocm():
+        pytest.skip(
+            "vllm_flash_attn is not supported for vLLM on ROCm.",
+            allow_module_level=True,
+        )
+

 NUM_HEADS = [(4, 4), (8, 2)]
 HEAD_SIZES = [40, 72, 80, 128, 256]

tests/kernels/attention/test_flashinfer.py

Lines changed: 10 additions & 2 deletions
@@ -2,12 +2,20 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project


-import flashinfer
 import pytest
-import torch

 from vllm.platforms import current_platform

+try:
+    import flashinfer
+except ImportError:
+    if current_platform.is_rocm():
+        pytest.skip(
+            "flashinfer is not supported for vLLM on ROCm.", allow_module_level=True
+        )
+
+import torch
+
 NUM_HEADS = [(32, 8), (6, 1)]
 HEAD_SIZES = [128, 256]
 BLOCK_SIZES = [16, 32]

tests/kernels/attention/test_flashinfer_mla_decode.py

Lines changed: 2 additions & 1 deletion
@@ -3,7 +3,6 @@
 import pytest
 import torch
 import torch.nn.functional as F
-from flashinfer.decode import trtllm_batch_decode_with_kv_cache_mla
 from torch import Tensor

 from vllm.platforms import current_platform
@@ -15,6 +14,8 @@
         reason="FlashInfer MLA Requires compute capability of 10 or above.",
         allow_module_level=True,
     )
+else:
+    from flashinfer.decode import trtllm_batch_decode_with_kv_cache_mla


 def ref_mla(
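Here (and in test_flashinfer_trtllm_attention.py below) the flashinfer import is instead moved into the else branch of the module's existing capability-based skip, so it only executes when the tests are actually going to run. A rough standalone sketch of that shape, with a placeholder skip condition and message since the guard itself sits outside the diff hunk:

import pytest

from vllm.platforms import current_platform

if not current_platform.has_device_capability(100):  # placeholder threshold
    pytest.skip(
        "These tests need a newer GPU architecture; skipping the module.",
        allow_module_level=True,
    )
else:
    # pytest.skip(allow_module_level=True) raises during collection, so on
    # unsupported platforms this import is never reached and cannot fail.
    import flashinfer  # noqa: F401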

tests/kernels/attention/test_flashinfer_trtllm_attention.py

Lines changed: 2 additions & 1 deletion
@@ -1,7 +1,6 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project

-import flashinfer
 import pytest
 import torch

@@ -16,6 +15,8 @@
     pytest.skip(
         "This TRTLLM kernel requires NVIDIA Blackwell.", allow_module_level=True
     )
+else:
+    import flashinfer

 FLOAT32_BYTES = torch.finfo(torch.float).bits // 8
 FP8_DTYPE = current_platform.fp8_dtype()

tests/kernels/moe/test_flashinfer.py

Lines changed: 8 additions & 1 deletion
@@ -22,7 +22,14 @@
 from vllm.model_executor.layers.quantization.utils.fp8_utils import input_to_float8
 from vllm.model_executor.models.llama4 import Llama4MoE
 from vllm.platforms import current_platform
-from vllm.utils.flashinfer import has_flashinfer_cutlass_fused_moe
+
+try:
+    from vllm.utils.flashinfer import has_flashinfer_cutlass_fused_moe
+except ImportError:
+    if current_platform.is_rocm():
+        pytest.skip(
+            "flashinfer not supported for vLLM on ROCm", allow_module_level=True
+        )

 if not has_flashinfer_cutlass_fused_moe() or not current_platform.has_device_capability(
     90
