|
39 | 39 | NUM_EXPERTS = [8, 64] |
40 | 40 | TOP_KS = [1, 2, 6] |
41 | 41 |
|
| 42 | +DTYPES = [torch.bfloat16] |
| 43 | + |
| 44 | +if not current_platform.is_fp8_fnuz(): |
| 45 | + DTYPES.append(torch.float8_e4m3fn) |
| 46 | + |
42 | 47 | vllm_config = VllmConfig() |
43 | 48 |
|
44 | 49 |
|
@@ -96,7 +101,7 @@ def make_tensors(config: BatchedMMConfig): |
96 | 101 | @pytest.mark.parametrize("max_tokens_per_expert", [32, 224, 512]) |
97 | 102 | @pytest.mark.parametrize("K", [128, 1024]) |
98 | 103 | @pytest.mark.parametrize("N", [128, 1024]) |
99 | | -@pytest.mark.parametrize("dtype", [torch.float8_e4m3fn, torch.bfloat16]) |
| 104 | +@pytest.mark.parametrize("dtype", DTYPES) |
100 | 105 | @pytest.mark.parametrize("block_shape", [None, [128, 128]]) |
101 | 106 | @pytest.mark.parametrize("per_act_token_quant", [False, True]) |
102 | 107 | def test_batched_mm( |
@@ -229,7 +234,7 @@ def test_batched_mm( |
229 | 234 | @pytest.mark.parametrize(("m", "n", "k"), MNK_FACTORS) |
230 | 235 | @pytest.mark.parametrize("e", NUM_EXPERTS) |
231 | 236 | @pytest.mark.parametrize("topk", TOP_KS) |
232 | | -@pytest.mark.parametrize("dtype", [torch.float8_e4m3fn, torch.bfloat16]) |
| 237 | +@pytest.mark.parametrize("dtype", DTYPES) |
233 | 238 | @pytest.mark.parametrize("per_act_token_quant", [False, True]) |
234 | 239 | @pytest.mark.parametrize("block_shape", [None, [128, 128]]) |
235 | 240 | @pytest.mark.parametrize("input_scales", [False]) |
|
0 commit comments