Commit a38b8af

[NVIDIA] Add SM100 Flashinfer Cutlass MoE fp8 backend (#22357)
Signed-off-by: Amir Klein <203507526+amirkl94@users.noreply.github.com>
1 parent: 21dce80

6 files changed: +612 -138 lines changed

.buildkite/test-pipeline.yaml

Lines changed: 2 additions & 0 deletions

@@ -630,6 +630,7 @@ steps:
   - vllm/model_executor/layers/fused_moe/cutlass_moe.py
   - vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py
   - vllm/model_executor/layers/fused_moe/flashinfer_cutlass_prepare_finalize.py
+  - vllm/model_executor/layers/quantization/utils/flashinfer_utils.py
   - vllm/v1/attention/backends/flashinfer.py
   - vllm/compilation/fusion.py
   - vllm/compilation/fusion_attn.py
@@ -650,6 +651,7 @@ steps:
   # Fusion
   - pytest -v -s tests/compile/test_fusion_all_reduce.py
   - pytest -v -s tests/compile/test_fusion_attn.py::test_attention_quant_pattern
+  - pytest -v -s tests/kernels/moe/test_flashinfer.py

 ##### 1 GPU test #####
 ##### multi gpus test #####
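For local verification, the same command the pipeline step above now runs can be invoked directly. A minimal sketch using pytest's Python entry point, assuming an SM100 GPU and a FlashInfer build that provides the CUTLASS fused-MoE extension (the test module itself skips otherwise):

# Illustrative only: run the newly wired-in MoE FlashInfer tests the same
# way the Buildkite step above does.
import sys

import pytest

sys.exit(pytest.main(["-v", "-s", "tests/kernels/moe/test_flashinfer.py"]))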
Lines changed: 248 additions & 0 deletions (new file)

@@ -0,0 +1,248 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+from dataclasses import dataclass
+
+import pytest
+import torch
+
+from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config
+from vllm.model_executor.layers.fused_moe.fused_moe import fused_experts
+from vllm.model_executor.layers.fused_moe.layer import FusedMoE
+from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
+    apply_flashinfer_per_tensor_scale_fp8, flashinfer_cutlass_moe_fp8,
+    register_moe_scaling_factors, rotate_flashinfer_fp8_moe_weights,
+    swap_w13_to_w31)
+from vllm.model_executor.layers.quantization.utils.fp8_utils import (
+    input_to_float8)
+from vllm.model_executor.models.llama4 import Llama4MoE
+from vllm.platforms import current_platform
+from vllm.utils.flashinfer import has_flashinfer_cutlass_fused_moe
+
+if not has_flashinfer_cutlass_fused_moe(
+) or not current_platform.has_device_capability(100):
+    pytest.skip("Requires flashinfer_cutlass_fused_moe and nvfp4 support",
+                allow_module_level=True)
+
+NUM_EXPERTS = [16]
+TOP_KS = [1]
+
+MNK_FACTORS = [
+    (256, 8192, 5120),
+    (256, 4096, 5120),
+    (127, 8192, 5120),
+    (127, 4096, 5120),
+    (10, 8192, 5120),
+    (10, 4096, 5120),
+    (1, 8192, 5120),
+    (1, 4096, 5120),
+]
+
+vllm_config = VllmConfig(parallel_config=ParallelConfig(
+    pipeline_parallel_size=1))
+vllm_config.scheduler_config.max_num_seqs = 128
+vllm_config.scheduler_config.max_model_len = 8192
+
+
+def quant_fp8_per_tensor_batches(a):
+    num_batches = a.size(0)
+    a_quant = []
+    a_scales = []
+
+    for i in range(num_batches):
+        a_fp8, a_global_sf = input_to_float8(a[i])
+        a_global_sf = 1.0 / a_global_sf
+        a_quant.append(a_fp8)
+        a_scales.append(a_global_sf)
+
+    result_a_quant = torch.stack(a_quant)
+    result_a_scales = torch.stack(a_scales)
+
+    return result_a_quant, result_a_scales
+
+
+@dataclass
+class TestData:
+    hidden_states: torch.Tensor
+    w13_quantized: torch.Tensor
+    w2_quantized: torch.Tensor
+    a1_scale: torch.Tensor
+    a2_scale: torch.Tensor
+    w13_weight_scale: torch.Tensor
+    w2_weight_scale: torch.Tensor
+    layer: torch.nn.Module
+
+    @staticmethod
+    def make_moe_tensors_8bit(m: int, k: int, n: int, e: int,
+                              reorder: bool) -> "TestData":
+        hidden_states = torch.randn(
+            (m, k), device="cuda", dtype=torch.bfloat16) / 10
+        w13 = torch.randn((e, 2 * n, k), device="cuda", dtype=torch.bfloat16)
+        w2 = torch.randn((e, k, n), device="cuda", dtype=torch.bfloat16)
+
+        # Scale to fp8
+        _, a1_scale = input_to_float8(hidden_states)
+        a1_scale = 1.0 / a1_scale
+        a2_scale = torch.scalar_tensor(1.0).to(device="cuda").to(
+            dtype=torch.float32)
+        w13_quantized, w13_weight_scale = quant_fp8_per_tensor_batches(w13)
+        w2_quantized, w2_weight_scale = quant_fp8_per_tensor_batches(w2)
+
+        layer = torch.nn.Module()
+        layer.w13_weight = w13_quantized.clone()
+        layer.w2_weight = w2_quantized.clone()
+        layer.w13_input_scale = a1_scale
+        layer.w2_input_scale = a2_scale
+        layer.w13_weight_scale = w13_weight_scale
+        layer.w2_weight_scale = w2_weight_scale
+
+        register_moe_scaling_factors(layer)
+
+        # flashinfer expects swapped rows for w13
+        layer.w13_weight.data = swap_w13_to_w31(layer.w13_weight.data)
+        if reorder:
+            rotate_flashinfer_fp8_moe_weights(layer.w13_weight,
+                                              layer.w2_weight)
+        layer.custom_routing_function = Llama4MoE.custom_routing_function
+        layer.intermediate_size_per_partition = n
+        layer.ep_rank = 0
+        layer.local_num_experts = e
+
+        return TestData(
+            hidden_states=hidden_states,
+            w13_quantized=w13_quantized,
+            w2_quantized=w2_quantized,
+            a1_scale=a1_scale,
+            a2_scale=a2_scale,
+            w13_weight_scale=w13_weight_scale,
+            w2_weight_scale=w2_weight_scale,
+            layer=layer,
+        )
+
+
+@pytest.mark.parametrize("m,n,k", MNK_FACTORS)
+@pytest.mark.parametrize("e", NUM_EXPERTS)
+@pytest.mark.parametrize("topk", TOP_KS)
+def test_flashinfer_per_tensor_moe_fp8_no_graph(
+    m: int,
+    n: int,
+    k: int,
+    e: int,
+    topk: int,
+    monkeypatch,
+):
+    current_platform.seed_everything(7)
+    monkeypatch.setenv("VLLM_FUSED_MOE_CHUNK_SIZE", "8192")
+    with set_current_vllm_config(vllm_config):
+        td = TestData.make_moe_tensors_8bit(m, k, n, e, reorder=True)
+
+        score = torch.randn((m, e), device="cuda", dtype=torch.bfloat16)
+        topk_weights, topk_ids = FusedMoE.select_experts(
+            hidden_states=td.hidden_states,
+            router_logits=score,
+            use_grouped_topk=False,
+            top_k=topk,
+            renormalize=False,
+            custom_routing_function=Llama4MoE.custom_routing_function,
+            scoring_func="softmax")
+
+        output = fused_experts(
+            td.hidden_states,
+            td.w13_quantized,
+            td.w2_quantized,
+            topk_weights=topk_weights,
+            topk_ids=topk_ids,
+            inplace=False,
+            activation="silu",
+            use_fp8_w8a8=True,
+            per_channel_quant=False,
+            global_num_experts=e,
+            expert_map=None,
+            w1_scale=td.w13_weight_scale,
+            w2_scale=td.w2_weight_scale,
+            a1_scale=td.a1_scale,
+            a2_scale=td.a2_scale,
+            apply_router_weight_on_input=True,
+        )
+
+        flashinfer_output = apply_flashinfer_per_tensor_scale_fp8(
+            layer=td.layer,
+            hidden_states=td.hidden_states,
+            router_logits=score,
+            routing_bias=None,
+            global_num_experts=e,
+            top_k=topk,
+            num_expert_group=None,
+            topk_group=None,
+            apply_router_weight_on_input=True)
+
+        torch.testing.assert_close(output,
+                                   flashinfer_output,
+                                   atol=5.5e-2,
+                                   rtol=1e-2)
+
+
+@pytest.mark.skip(
+    "Requires flashinfer version that contains https://github.com/flashinfer-ai/flashinfer/pull/1472"
+)
+@pytest.mark.parametrize("m,n,k", MNK_FACTORS)
+@pytest.mark.parametrize("e", NUM_EXPERTS)
+@pytest.mark.parametrize("topk", TOP_KS)
+def test_flashinfer_cutlass_moe_fp8_no_graph(
+    m: int,
+    n: int,
+    k: int,
+    e: int,
+    topk: int,
+    monkeypatch,
+):
+    current_platform.seed_everything(7)
+    monkeypatch.setenv("VLLM_FUSED_MOE_CHUNK_SIZE", "8192")
+    with set_current_vllm_config(vllm_config):
+        td = TestData.make_moe_tensors_8bit(m, k, n, e, reorder=False)

+        score = torch.randn((m, e), device="cuda", dtype=torch.bfloat16)
+        topk_weights, topk_ids = FusedMoE.select_experts(
+            hidden_states=td.hidden_states,
+            router_logits=score,
+            use_grouped_topk=False,
+            top_k=topk,
+            renormalize=False,
+            custom_routing_function=Llama4MoE.custom_routing_function,
+            scoring_func="softmax")
+
+        output = fused_experts(
+            td.hidden_states,
+            td.w13_quantized,
+            td.w2_quantized,
+            topk_weights=topk_weights,
+            topk_ids=topk_ids,
+            inplace=False,
+            activation="silu",
+            use_fp8_w8a8=True,
+            per_channel_quant=False,
+            global_num_experts=e,
+            expert_map=None,
+            w1_scale=td.w13_weight_scale,
+            w2_scale=td.w2_weight_scale,
+            a1_scale=td.a1_scale,
+            a2_scale=td.a2_scale,
+            apply_router_weight_on_input=True,
+        )
+
+        td.layer.dp_size = 1
+
+        flashinfer_cutlass_output = flashinfer_cutlass_moe_fp8(
+            td.hidden_states,
+            td.layer,
+            topk_weights,
+            topk_ids,
+            activation="silu",
+            global_num_experts=e,
+            expert_map=None,
+            apply_router_weight_on_input=True,
+        )
+
+        torch.testing.assert_close(output,
+                                   flashinfer_cutlass_output,
+                                   atol=5.5e-2,
+                                   rtol=1e-2)
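Both tests build per-tensor fp8 weights and activations and hand the FlashInfer path the reciprocal of the scale returned by input_to_float8 (the 1.0 / a_global_sf lines above), which appears to match the inverse-scale convention referenced by the "inv_scale" comment in flashinfer_cutlass_moe.py. A minimal, illustrative sketch of that per-tensor convention (a stand-in for clarity, not vLLM's actual input_to_float8 helper):

# Illustrative per-tensor fp8 quantization with the inverted-scale handoff
# the tests above rely on. Stand-in only; not vLLM's input_to_float8.
import torch

def quant_fp8_per_tensor(x: torch.Tensor):
    fp8_max = torch.finfo(torch.float8_e4m3fn).max
    # Dequantization scale: largest magnitude mapped onto the fp8 range.
    scale = (x.abs().max().float() / fp8_max).clamp(min=1e-12)
    x_fp8 = (x.float() / scale).clamp(-fp8_max, fp8_max).to(torch.float8_e4m3fn)
    # The reciprocal is what gets registered on the layer, mirroring the
    # 1.0 / scale steps in quant_fp8_per_tensor_batches / make_moe_tensors_8bit.
    return x_fp8, scale, 1.0 / scale

x = torch.randn(16, 128, dtype=torch.bfloat16, device="cuda") / 10
x_fp8, x_dequant_scale, x_inv_scale = quant_fp8_per_tensor(x)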

vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py

Lines changed: 33 additions & 22 deletions

@@ -61,8 +61,8 @@ def __init__(
                 per_act_token_quant=False,
                 block_shape=None,
             ))
-        assert quant_dtype == "nvfp4", ("Only nvfp4 quantization is "
-                                        "currently supported.")
+        assert quant_dtype in ("nvfp4", torch.float8_e4m3fn), (
+            "Only nvfp4,fp8 quantization are currently supported.")
         self.ep_rank = ep_rank
         self.ep_size = ep_size
         self.tp_rank = tp_rank
@@ -122,7 +122,8 @@ def workspace_shapes(
         """
         aq_m, aq_n = aq.shape
         workspace2 = ()
-        output_shape = (aq_m, aq_n * 2)
+        output_shape = (aq_m, aq_n * 2) if self.quant_dtype != \
+            torch.float8_e4m3fn else (aq_m, aq_n)
         workspace_dtype = a.dtype
         workspace1 = output_shape
         # The workspace is determined by `aq`, since it comes after any
@@ -151,29 +152,39 @@ def apply(
         expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
         apply_router_weight_on_input: Optional[bool],
     ):
-        # Flashinfer CUTLASS kernel takes scalar global scales,
-        # min because inv_scale.
-
-        # Ensure w1_scale and w2_scale are not None before calling view
-        assert w1_scale is not None and w2_scale is not None, (
-            "w1_scale and w2_scale must not "
-            "be None for FlashInferExperts")
-
-        quant_scales = [
-            self.a1_gscale,
-            w1_scale.view(torch.int32),
-            self.g1_alphas,
-            self.a2_gscale,
-            w2_scale.view(torch.int32),
-            self.g2_alphas,
-        ]
+        if self.quant_dtype == torch.float8_e4m3fn:
+            quant_scales = [
+                self.g1_alphas, self.a2_gscale, self.g2_alphas, self.a1_gscale
+            ]
+
+            a1q_scale = None  # not passing input_sf in fp8
+            fc1_expert_weights = w1
+            fc2_expert_weights = w2
+        else:
+            # Ensure w1_scale and w2_scale are not None before calling view
+            assert w1_scale is not None and w2_scale is not None, (
+                "w1_scale and w2_scale must not "
+                "be None for FlashInferExperts")
+            # Flashinfer CUTLASS kernel takes scalar global scales,
+            # min because inv_scale.
+            quant_scales = [
+                self.a1_gscale,
+                w1_scale.view(torch.int32),
+                self.g1_alphas,
+                self.a2_gscale,
+                w2_scale.view(torch.int32),
+                self.g2_alphas,
+            ]
+            # FlashInfer API requires weight to be long for nvfp4
+            fc1_expert_weights = w1.view(torch.long)
+            fc2_expert_weights = w2.view(torch.long)
+
         _ = flashinfer_cutlass_fused_moe(
             input=hidden_states,
             token_selected_experts=topk_ids.to(torch.int),
             token_final_scales=topk_weights,
-            # FlashInfer API requires weight to be long for nvfp4
-            fc1_expert_weights=w1.view(torch.long),
-            fc2_expert_weights=w2.view(torch.long),
+            fc1_expert_weights=fc1_expert_weights,
+            fc2_expert_weights=fc2_expert_weights,
             output_dtype=self.out_dtype,
             quant_scales=quant_scales,
             input_sf=a1q_scale,
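The substantive change above is that FlashInferExperts.apply now chooses quant_scales, weight views, and input-scale handling based on the quantization dtype instead of assuming nvfp4. A minimal standalone restatement of that dispatch, with the self.* attributes passed explicitly through a hypothetical helper, purely for illustration:

# Illustrative restatement of the fp8-vs-nvfp4 branch added in apply().
# Hypothetical helper; the real logic lives inside FlashInferExperts.apply.
import torch

def build_flashinfer_moe_inputs(quant_dtype, w1, w2, w1_scale, w2_scale,
                                a1_gscale, a2_gscale, g1_alphas, g2_alphas,
                                a1q_scale):
    if quant_dtype == torch.float8_e4m3fn:
        # fp8: four scalar global scales, raw fp8 weights, and no input
        # scale factor handed to the kernel.
        quant_scales = [g1_alphas, a2_gscale, g2_alphas, a1_gscale]
        a1q_scale = None
        fc1_expert_weights, fc2_expert_weights = w1, w2
    else:
        # nvfp4: global scales interleaved with the weight scale tensors
        # viewed as int32, and the packed weights viewed as int64 as the
        # FlashInfer API expects.
        assert w1_scale is not None and w2_scale is not None
        quant_scales = [
            a1_gscale, w1_scale.view(torch.int32), g1_alphas,
            a2_gscale, w2_scale.view(torch.int32), g2_alphas,
        ]
        fc1_expert_weights = w1.view(torch.long)
        fc2_expert_weights = w2.view(torch.long)
    return quant_scales, a1q_scale, fc1_expert_weights, fc2_expert_weights

The returned values correspond one-to-one to the fc1_expert_weights, fc2_expert_weights, quant_scales, and input_sf arguments passed to flashinfer_cutlass_fused_moe in the diff.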
