
Commit fd65015

rasmith and Randall Smith authored
[CI/Build] Only use supported types and features on ROCm in MoE kernel tests (#29149)
Signed-off-by: Randall Smith <ransmith@amd.com>
Co-authored-by: Randall Smith <ransmith@amd.com>
1 parent 77e1c03 commit fd65015
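
The gating predicate used throughout this commit, current_platform.is_fp8_fnuz(), reports whether the device's native FP8 is the fnuz variant (torch.float8_e4m3fnuz), as on MI300-class ROCm GPUs; that format is not interchangeable with the OCP torch.float8_e4m3fn these tests previously hard-coded. A quick check in stock PyTorch shows the numeric difference (a sketch, assuming a PyTorch build that ships both dtypes):

import torch

# OCP e4m3fn ("fn" = finite, no infinities): max normal value 448.
# fnuz variant: max normal value 240, no negative zero, single NaN encoding.
print(torch.finfo(torch.float8_e4m3fn).max)    # 448.0
print(torch.finfo(torch.float8_e4m3fnuz).max)  # 240.0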

File tree

7 files changed: +41 −2 lines changed

tests/kernels/moe/test_batched_moe.py
tests/kernels/moe/test_block_fp8.py
tests/kernels/moe/test_gpt_oss_triton_kernels.py
tests/kernels/moe/test_modular_kernel_combinations.py
tests/kernels/moe/test_moe_permute_unpermute.py
tests/kernels/moe/test_silu_mul_fp8_quant_deep_gemm.py
tests/kernels/moe/test_triton_moe_ptpc_fp8.py

tests/kernels/moe/test_batched_moe.py

Lines changed: 7 additions & 2 deletions
@@ -39,6 +39,11 @@
 NUM_EXPERTS = [8, 64]
 TOP_KS = [1, 2, 6]

+DTYPES = [torch.bfloat16]
+
+if not current_platform.is_fp8_fnuz():
+    DTYPES.append(torch.float8_e4m3fn)
+
 vllm_config = VllmConfig()

@@ -96,7 +101,7 @@ def make_tensors(config: BatchedMMConfig):
 @pytest.mark.parametrize("max_tokens_per_expert", [32, 224, 512])
 @pytest.mark.parametrize("K", [128, 1024])
 @pytest.mark.parametrize("N", [128, 1024])
-@pytest.mark.parametrize("dtype", [torch.float8_e4m3fn, torch.bfloat16])
+@pytest.mark.parametrize("dtype", DTYPES)
 @pytest.mark.parametrize("block_shape", [None, [128, 128]])
 @pytest.mark.parametrize("per_act_token_quant", [False, True])
 def test_batched_mm(

@@ -229,7 +234,7 @@ def test_batched_mm(
 @pytest.mark.parametrize(("m", "n", "k"), MNK_FACTORS)
 @pytest.mark.parametrize("e", NUM_EXPERTS)
 @pytest.mark.parametrize("topk", TOP_KS)
-@pytest.mark.parametrize("dtype", [torch.float8_e4m3fn, torch.bfloat16])
+@pytest.mark.parametrize("dtype", DTYPES)
 @pytest.mark.parametrize("per_act_token_quant", [False, True])
 @pytest.mark.parametrize("block_shape", [None, [128, 128]])
 @pytest.mark.parametrize("input_scales", [False])
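
Building DTYPES at import time keeps unsupported fp8 cases out of the parametrization entirely, rather than generating them and skipping each one. A minimal standalone sketch of the same pattern, where has_fp8_e4m3fn() is a hypothetical stand-in for the platform probe (real code would consult vllm.platforms.current_platform.is_fp8_fnuz()):

import pytest
import torch

def has_fp8_e4m3fn() -> bool:
    # Hypothetical probe; substitute the real platform check here.
    return hasattr(torch, "float8_e4m3fn")

DTYPES = [torch.bfloat16]
if has_fp8_e4m3fn():
    DTYPES.append(torch.float8_e4m3fn)

@pytest.mark.parametrize("dtype", DTYPES)
def test_cast_roundtrip(dtype):
    # Each collected case uses only a dtype the platform supports.
    x = torch.randn(4, 4).to(dtype)
    assert x.dtype == dtype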

tests/kernels/moe/test_block_fp8.py

Lines changed: 5 additions & 0 deletions
@@ -31,6 +31,11 @@

 if current_platform.get_device_capability() < (9, 0):
     pytest.skip("FP8 Triton requires CUDA 9.0 or higher", allow_module_level=True)
+if current_platform.is_fp8_fnuz():
+    pytest.skip(
+        "Tests in this file require float8_e4m3fn and platform does not support",
+        allow_module_level=True,
+    )

 vllm_config = VllmConfig()
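
Calling pytest.skip() at module scope normally raises a usage error so that stray skips in helper code are caught; passing allow_module_level=True marks the skip as intentional and aborts collection of the whole file before any test in it runs. A sketch of the guard in isolation, with a hypothetical stand-in for the platform check:

import pytest

def platform_lacks_fp8() -> bool:
    # Hypothetical stand-in for current_platform.is_fp8_fnuz().
    return False

if platform_lacks_fp8():
    # Skips every test below; without allow_module_level=True this
    # call would error out at import time instead of skipping.
    pytest.skip("platform lacks torch.float8_e4m3fn", allow_module_level=True)

def test_example():
    assert True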

tests/kernels/moe/test_gpt_oss_triton_kernels.py

Lines changed: 5 additions & 0 deletions
@@ -270,6 +270,11 @@ class Case:
 @pytest.mark.parametrize("num_token", [2])
 @pytest.mark.parametrize("tp", [1, 2, 4, 8])
 def test_equiv(num_token, a_dtype, w_dtype, tp):
+    from triton_kernels.tensor_details import layout
+
+    if not hasattr(layout, "make_default_matmul_mxfp4_w_layout"):
+        pytest.skip("make_default_matmul_mxfp4_w_layout not available")
+
     M = num_token
     E = ModelConfig.num_experts
     K = ModelConfig.hidden_size
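
The availability of make_default_matmul_mxfp4_w_layout evidently varies across triton_kernels builds, so the test probes for the symbol with hasattr rather than pinning a version. The same feature-detection idiom in isolation, with hypothetical package and attribute names:

import pytest

def test_optional_fast_path():
    # importorskip skips the test if the package itself is missing.
    mod = pytest.importorskip("optional_pkg")  # hypothetical package
    # hasattr guards against installed versions that predate the symbol.
    if not hasattr(mod, "fast_path"):  # hypothetical attribute
        pytest.skip("fast_path not available in this optional_pkg version")
    assert callable(mod.fast_path)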

tests/kernels/moe/test_modular_kernel_combinations.py

Lines changed: 6 additions & 0 deletions
@@ -46,6 +46,12 @@
     reason="Requires deep_ep or deep_gemm or pplx or flashinfer packages",
 )

+if current_platform.is_fp8_fnuz():
+    pytest.skip(
+        "Tests in this file require float8_e4m3fn and platform does not support",
+        allow_module_level=True,
+    )
+

 def format_result(verbose, msg, ex=None):
     if ex is not None:

tests/kernels/moe/test_moe_permute_unpermute.py

Lines changed: 6 additions & 0 deletions
@@ -23,6 +23,12 @@
 EP_SIZE = [1, 4, 16]
 current_platform.seed_everything(0)

+if current_platform.is_rocm():
+    pytest.skip(
+        "moe_permute_unpermute_supported is not defined for ROCm",
+        allow_module_level=True,
+    )
+

 def torch_permute(
     hidden_states: torch.Tensor,

tests/kernels/moe/test_silu_mul_fp8_quant_deep_gemm.py

Lines changed: 6 additions & 0 deletions
@@ -14,6 +14,12 @@
 from vllm.utils.deep_gemm import DeepGemmQuantScaleFMT, has_deep_gemm
 from vllm.utils.math_utils import cdiv, round_up

+if current_platform.is_fp8_fnuz():
+    pytest.skip(
+        "Tests in this file require float8_e4m3fn and platform does not support",
+        allow_module_level=True,
+    )
+
 fp8_dtype = torch.float8_e4m3fn

 CASES = [

tests/kernels/moe/test_triton_moe_ptpc_fp8.py

Lines changed: 6 additions & 0 deletions
@@ -19,6 +19,12 @@

 vllm_config = VllmConfig()

+if current_platform.is_fp8_fnuz():
+    pytest.skip(
+        "Tests in this file require float8_e4m3fn and platform does not support",
+        allow_module_level=True,
+    )
+

 def native_w8a8_per_token_matmul(A, B, As, Bs, output_dtype=torch.float16):
     """Matrix multiplication function that supports per-token input
