
Commit 6160ba4

feat: BF16 FlashInfer Fused Cutlass MOE for Hopper and Blackwell Expert Parallel (vllm-project#25503)
Signed-off-by: Duncan Moss <djm.moss@gmail.com>
1 parent fea8006 commit 6160ba4

File tree: 5 files changed, +121 -6 lines

vllm/envs.py

Lines changed: 6 additions & 0 deletions
@@ -144,6 +144,7 @@
     VLLM_USE_DEEP_GEMM_E8M0_HOPPER: bool = False
     VLLM_SKIP_DEEP_GEMM_WARMUP: bool = False
     VLLM_USE_FUSED_MOE_GROUPED_TOPK: bool = True
+    VLLM_USE_FLASHINFER_MOE_FP16: bool = False
     VLLM_USE_FLASHINFER_MOE_FP8: bool = False
     VLLM_USE_FLASHINFER_MOE_FP4: bool = False
     VLLM_FLASHINFER_MOE_BACKEND: Literal["throughput",
@@ -1145,6 +1146,10 @@ def get_vllm_port() -> Optional[int]:
     "VLLM_USE_FUSED_MOE_GROUPED_TOPK":
     lambda: bool(int(os.getenv("VLLM_USE_FUSED_MOE_GROUPED_TOPK", "1"))),
 
+    # Allow use of FlashInfer MoE kernels for fused moe ops.
+    "VLLM_USE_FLASHINFER_MOE_FP16":
+    lambda: bool(int(os.getenv("VLLM_USE_FLASHINFER_MOE_FP16", "0"))),
+
     # Allow use of FlashInfer MoE kernels for fused moe ops.
     "VLLM_USE_FLASHINFER_MOE_FP8":
     lambda: bool(int(os.getenv("VLLM_USE_FLASHINFER_MOE_FP8", "0"))),
@@ -1516,6 +1521,7 @@ def compute_hash() -> str:
         "VLLM_USE_DEEP_GEMM_E8M0_HOPPER",
         "VLLM_USE_TRTLLM_FP4_GEMM",
         "VLLM_USE_FUSED_MOE_GROUPED_TOPK",
+        "VLLM_USE_FLASHINFER_MOE_FP16",
         "VLLM_USE_FLASHINFER_MOE_FP8",
         "VLLM_USE_FLASHINFER_MOE_FP4",
         "VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8",

vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py

Lines changed: 60 additions & 5 deletions
@@ -52,8 +52,10 @@ def __init__(
         tp_size: int = 1,
     ):
         super().__init__(quant_config)
-        assert quant_config.quant_dtype in ("nvfp4", torch.float8_e4m3fn), (
-            "Only nvfp4,fp8 quantization are currently supported.")
+        assert quant_config.quant_dtype in (
+            "nvfp4", torch.float8_e4m3fn,
+            None), ("Only nvfp4, fp8, bfloat16 and"
+                    " float16 quantization are currently supported.")
         self.ep_rank = ep_rank
         self.ep_size = ep_size
         self.tp_rank = tp_rank
@@ -109,8 +111,9 @@ def workspace_shapes(
         """
         aq_m, aq_n = aq.shape
         workspace2 = (0, )
-        output_shape = (aq_m, aq_n * 2) if self.quant_dtype != \
-            torch.float8_e4m3fn else (aq_m, aq_n)
+        output_shape = (aq_m,
+                        aq_n * 2) if self.quant_dtype == "nvfp4" else (aq_m,
+                                                                       aq_n)
         workspace_dtype = a.dtype
         workspace1 = output_shape
         # The workspace is determined by `aq`, since it comes after any
@@ -135,6 +138,10 @@ def apply(
         expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
         apply_router_weight_on_input: Optional[bool],
     ):
+
+        assert activation == "silu", ("Only activation silu is supported in "
+                                      "FlashInferExperts")
+
         if self.quant_dtype == torch.float8_e4m3fn:
             quant_scales = [
                 self.g1_alphas, self.a2_gscale, self.g2_alphas, self.a1_gscale
@@ -143,7 +150,7 @@ def apply(
             a1q_scale = None  # not passing input_sf in fp8
             fc1_expert_weights = w1
             fc2_expert_weights = w2
-        else:
+        elif self.quant_dtype == "nvfp4":
             # Ensure w1_scale and w2_scale are not None before calling view
             assert self.w1_scale is not None and self.w2_scale is not None, (
                 "w1_scale and w2_scale must not "
@@ -161,6 +168,11 @@ def apply(
             # FlashInfer API requires weight to be long for nvfp4
             fc1_expert_weights = w1.view(torch.long)
             fc2_expert_weights = w2.view(torch.long)
+        else:
+            quant_scales = None
+            a1q_scale = None
+            fc1_expert_weights = w1
+            fc2_expert_weights = w2
 
         _ = flashinfer_cutlass_fused_moe(
             input=hidden_states,
@@ -211,3 +223,46 @@ def flashinfer_cutlass_moe_fp4(
         expert_map=expert_map,
         apply_router_weight_on_input=apply_router_weight_on_input,
     )
+
+
+def flashinfer_cutlass_moe(
+    hidden_states: torch.Tensor,
+    w1: torch.Tensor,
+    w2: torch.Tensor,
+    topk_weights: torch.Tensor,
+    topk_ids: torch.Tensor,
+    quant_config: FusedMoEQuantConfig,
+    inplace: bool = False,
+    activation: str = "silu",
+    global_num_experts: int = -1,
+    expert_map: Optional[torch.Tensor] = None,
+    apply_router_weight_on_input: bool = False,
+    tp_rank: int = 0,
+    tp_size: int = 1,
+    ep_rank: int = 0,
+    ep_size: int = 1,
+    use_dp: bool = False,
+) -> torch.Tensor:
+    fused_experts = mk.FusedMoEModularKernel(
+        create_flashinfer_prepare_finalize(use_dp=use_dp),
+        FlashInferExperts(
+            out_dtype=hidden_states.dtype,
+            quant_config=quant_config,
+            tp_rank=tp_rank,
+            tp_size=tp_size,
+            ep_rank=ep_rank,
+            ep_size=ep_size,
+        ))
+
+    return fused_experts(
+        hidden_states=hidden_states,
+        w1=w1,
+        w2=w2,
+        topk_weights=topk_weights,
+        topk_ids=topk_ids,
+        inplace=inplace,
+        activation=activation,
+        global_num_experts=global_num_experts,
+        expert_map=expert_map,
+        apply_router_weight_on_input=apply_router_weight_on_input,
+    )
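
For orientation, here is a hypothetical call into the new wrapper with random BF16 tensors; it is not part of the commit. The import path of FUSED_MOE_UNQUANTIZED_CONFIG, the [w3; w1]-stacked w13 layout, the tensor shapes, and the topk dtypes follow vLLM's usual fused-MoE conventions but should be treated as assumptions, and running it requires a Hopper or Blackwell GPU with FlashInfer installed (the kernel may also impose alignment constraints on the toy sizes chosen here).

import torch

from vllm.model_executor.layers.fused_moe.config import (
    FUSED_MOE_UNQUANTIZED_CONFIG)
from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import (
    flashinfer_cutlass_moe)

# Toy sizes: experts, hidden dim, intermediate dim, tokens, top-k.
E, H, I_SZ, T, K = 8, 1024, 2048, 16, 2
device, dtype = "cuda", torch.bfloat16

hidden_states = torch.randn(T, H, device=device, dtype=dtype)
w13 = torch.randn(E, 2 * I_SZ, H, device=device, dtype=dtype)  # [w3; w1] halves
w2 = torch.randn(E, H, I_SZ, device=device, dtype=dtype)
topk_weights = torch.rand(T, K, device=device, dtype=torch.float32)
topk_ids = torch.randint(0, E, (T, K), device=device, dtype=torch.int32)

out = flashinfer_cutlass_moe(
    hidden_states=hidden_states,
    w1=w13,
    w2=w2,
    topk_weights=topk_weights,
    topk_ids=topk_ids,
    quant_config=FUSED_MOE_UNQUANTIZED_CONFIG,
    activation="silu",  # the only activation FlashInferExperts accepts
    global_num_experts=E,
)
print(out.shape)  # expected: torch.Size([T, H])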

vllm/model_executor/layers/fused_moe/flashinfer_cutlass_prepare_finalize.py

Lines changed: 2 additions & 1 deletion
@@ -183,7 +183,8 @@ def prepare(
                 dim=0,
                 sizes=get_local_sizes(),
             )
-            a1q_scale = nvfp4_block_scale_interleave(a1q_scale)
+            if quant_config.quant_dtype == "nvfp4":
+                a1q_scale = nvfp4_block_scale_interleave(a1q_scale)
 
         return a1q, a1q_scale, None, topk_ids, topk_weights

vllm/model_executor/layers/fused_moe/layer.py

Lines changed: 51 additions & 0 deletions
@@ -39,6 +39,7 @@
 from vllm.platforms.interface import CpuArchEnum
 from vllm.utils import (cdiv, direct_register_custom_op, has_deep_ep, has_pplx,
                         round_up)
+from vllm.utils.flashinfer import has_flashinfer_cutlass_fused_moe
 from vllm.v1.worker.ubatching import dbo_current_ubatch_id
 
 if current_platform.is_cuda_alike():
@@ -296,6 +297,40 @@ def __init__(self, moe: FusedMoEConfig):
         else:
             self.rocm_aiter_fused_experts = None  # type: ignore
 
+        # FlashInfer CUTLASS MoE is only supported on Hopper and later GPUs
+        self.flashinfer_cutlass_moe_enabled = (
+            has_flashinfer_cutlass_fused_moe()
+            and envs.VLLM_USE_FLASHINFER_MOE_FP16
+            and self.moe.moe_parallel_config.use_ep
+            and self.moe.moe_parallel_config.dp_size == 1
+            and current_platform.get_device_capability()[0] >= 9)
+        if self.flashinfer_cutlass_moe_enabled:
+            logger.info_once(
+                "Enabling FlashInfer CUTLASS MoE for UnquantizedFusedMoEMethod"
+            )
+            from functools import partial
+
+            from .flashinfer_cutlass_moe import flashinfer_cutlass_moe
+            self.flashinfer_cutlass_moe = partial(
+                flashinfer_cutlass_moe,
+                quant_config=FUSED_MOE_UNQUANTIZED_CONFIG,
+                tp_rank=self.moe.moe_parallel_config.tp_rank,
+                tp_size=self.moe.moe_parallel_config.tp_size,
+                ep_rank=self.moe.moe_parallel_config.ep_rank,
+                ep_size=self.moe.moe_parallel_config.ep_size)
+        else:
+            if (self.moe.moe_parallel_config.use_ep
+                    and self.moe.moe_parallel_config.dp_size == 1):
+                logger.info_once(
+                    "FlashInfer CUTLASS MoE is available for EP"
+                    " but not enabled, consider setting"
+                    " VLLM_USE_FLASHINFER_MOE_FP16=1 to enable it.")
+            elif self.moe.moe_parallel_config.dp_size > 1:
+                logger.info_once(
+                    "FlashInfer CUTLASS MoE is currently not available for DP."
+                )
+            self.flashinfer_cutlass_moe = None  # type: ignore
+
     def maybe_make_prepare_finalize(
             self) -> Optional[FusedMoEPrepareAndFinalize]:
         if self.rocm_aiter_moe_enabled:
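
As a quick illustration of the capability gate (not part of the diff): get_device_capability() reports the CUDA compute capability, and requiring a major version of at least 9 admits Hopper (sm_90) and Blackwell (sm_100) parts while excluding Ampere and older GPUs.

# Illustrative only: which compute-capability majors pass the `>= 9` gate.
for major, arch in [(7, "Volta/Turing"), (8, "Ampere, e.g. A100"),
                    (9, "Hopper, e.g. H100"), (10, "Blackwell, e.g. B200")]:
    print(f"sm major {major:>2} ({arch}): eligible={major >= 9}")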
@@ -367,6 +402,7 @@ def _maybe_pad_weight(self, weight: torch.Tensor) -> torch.Tensor:
             num_pad = 256 // weight.element_size()
             weight = F.pad(weight, (0, num_pad), "constant", 0)[..., :-num_pad]
             torch.cuda.empty_cache()
+
         return weight
 
     def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
@@ -386,6 +422,12 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
             layer.w13_weight.data = shuffled_w13
             layer.w2_weight.data = shuffled_w2
 
+        if self.flashinfer_cutlass_moe_enabled:
+            # Swap halves to arrange as [w3; w1] (kernel expectation)
+            w1_w, w3_w = torch.chunk(layer.w13_weight.data, 2, dim=1)
+            w13_weight_swapped = torch.cat([w3_w, w1_w], dim=1)
+            layer.w13_weight.data = w13_weight_swapped.contiguous()
+
         if current_platform.is_xpu():
             import intel_extension_for_pytorch as ipex
             layer.ipex_fusion = ipex.llm.modules.GatedMLPMOE(
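
A small, self-contained sketch of the half swap above, using toy sizes: the two halves of the fused w13 weight simply trade places along dim=1, since the kernel expects the [w3; w1] ordering noted in the comment rather than the order the weights are loaded in (the sizes below are made up).

import torch

E, I_SZ, H = 2, 3, 4  # experts, intermediate dim, hidden dim (toy values)
w13 = torch.arange(E * 2 * I_SZ * H, dtype=torch.bfloat16).reshape(E, 2 * I_SZ, H)

# Split along the fused gate/up dimension and re-concatenate in swapped order.
w1_half, w3_half = torch.chunk(w13, 2, dim=1)
w13_swapped = torch.cat([w3_half, w1_half], dim=1).contiguous()

# The first I_SZ rows of each expert now hold what used to be the second half.
assert torch.equal(w13_swapped[:, :I_SZ], w13[:, I_SZ:])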
@@ -536,6 +578,15 @@ def forward_cuda(
                 expert_map=expert_map,
                 activation=activation,
                 apply_router_weight_on_input=apply_router_weight_on_input)
+        elif self.flashinfer_cutlass_moe_enabled:
+            return self.flashinfer_cutlass_moe(
+                hidden_states=x,
+                w1=layer.w13_weight,
+                w2=layer.w2_weight,
+                topk_weights=topk_weights,
+                topk_ids=topk_ids,
+                activation=activation,
+                apply_router_weight_on_input=apply_router_weight_on_input)
         elif self.fused_experts is not None:
             if self.moe.has_bias:
                 raise ValueError(

vllm/model_executor/layers/fused_moe/modular_kernel.py

Lines changed: 2 additions & 0 deletions
@@ -598,6 +598,8 @@ def __init__(self):
 
     def get(self, shape: tuple[int, ...], device: torch.device,
             dtype: torch.dtype):
+        if shape == () or shape is None:
+            return None
         shape_numel = prod(shape)
         if (self.buffer is None or self.buffer.numel() < shape_numel
                 or self.buffer.device != device or self.buffer.dtype != dtype):
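
To see why the early return matters: math.prod(()) is 1, so without the guard an empty workspace shape would fall through to the sizing logic and be treated as a one-element buffer instead of signalling that no workspace is needed. A simplified, runnable stand-in for the shared buffer cache (class and variable names here are illustrative, not vLLM's) behaves like this:

import math
from typing import Optional

import torch


class _SharedBuffer:

    def __init__(self):
        self.buffer: Optional[torch.Tensor] = None

    def get(self, shape, device, dtype):
        # New guard: an empty shape means "no workspace needed".
        if shape == () or shape is None:
            return None
        numel = math.prod(shape)
        if (self.buffer is None or self.buffer.numel() < numel
                or self.buffer.device != device or self.buffer.dtype != dtype):
            self.buffer = torch.empty(numel, device=device, dtype=dtype)
        return self.buffer[:numel].view(*shape)


buf = _SharedBuffer()
print(buf.get((), torch.device("cpu"), torch.float32))            # None
print(buf.get((4, 8), torch.device("cpu"), torch.float32).shape)  # torch.Size([4, 8])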
