@@ -434,14 +434,9 @@ def __init__(self, quant_config: Fp8Config, layer: torch.nn.Module):
         self.weight_block_size = self.quant_config.weight_block_size
         self.block_quant = self.weight_block_size is not None
 
-        self.flashinfer_moe_backend: Optional[FlashinferMoeBackend] = None
         self.fused_experts: Optional[
             mk.FusedMoEModularKernel] = None  # type: ignore
-        if envs.VLLM_USE_FLASHINFER_MOE_FP8 and has_flashinfer_moe():
-            self.flashinfer_moe_backend = get_flashinfer_moe_backend()
-            logger.info_once(
-                f"Using FlashInfer {self.flashinfer_moe_backend.value} kernels"
-            )
+
         # For GPUs that lack FP8 hardware support, we can leverage the Marlin
         # kernel for fast weight-only FP8 quantization
         self.use_marlin = (not current_platform.has_device_capability(89)
@@ -450,14 +445,27 @@ def __init__(self, quant_config: Fp8Config, layer: torch.nn.Module):
         if current_platform.is_rocm():
             self.use_marlin = False
 
+        # First check for Flashinfer MOE on Blackwell GPUs
+        self.flashinfer_moe_backend: Optional[FlashinferMoeBackend] = None
+        if (current_platform.is_cuda()
+                and current_platform.is_device_capability(100)
+                and envs.VLLM_USE_FLASHINFER_MOE_FP8 and has_flashinfer_moe()):
+            self.flashinfer_moe_backend = get_flashinfer_moe_backend()
+            logger.info_once(
+                f"Detected Blackwell GPUs, using FlashInfer "
+                f"{self.flashinfer_moe_backend.value} kernels for FP8 MOE.")
+
         # Check for DeepGemm support.
         self.allow_deep_gemm = False
         if envs.VLLM_USE_DEEP_GEMM:
             if not has_deep_gemm():
                 logger.warning_once("Failed to import DeepGemm kernels.")
             elif not self.block_quant:
-                logger.warning_once("Model is not block quantized. Not using "
-                                    "DeepGemm kernels")
+                logger.warning_once("Model is not block quantized. Not using"
+                                    " DeepGemm kernels")
+            elif self.flashinfer_moe_backend:
+                logger.info_once("DeepGemm disabled: FlashInfer MOE is"
+                                 " enabled.")
             elif (is_deep_gemm_supported()):
                 logger.info_once("Using DeepGemm kernels for Fp8MoEMethod.")
                 self.allow_deep_gemm = True
@@ -471,15 +479,12 @@ def __init__(self, quant_config: Fp8Config, layer: torch.nn.Module):
             logger.debug_once("Model is not block quantized. Not using "
                               "CutlassBlockScaledGroupedGemm kernels")
         elif (current_platform.is_cuda()
-              and current_platform.is_device_capability(100)):
+              and current_platform.is_device_capability(100)
+              and not self.flashinfer_moe_backend):
             logger.info_once(
-                "Using CutlassBlockScaledGroupedGemm kernels for Fp8MoEMethod. "
-            )
+                "Using CutlassBlockScaledGroupedGemm kernels for Fp8 MOE "
+                "on SM100.")
             self.allow_cutlass_block_scaled_grouped_gemm = True
-        else:
-            logger.warning_once(
-                "CutlassBlockScaledGroupedGemm not supported on the current "
-                "platform.")
 
     def create_weights(self, layer: Module, num_experts: int, hidden_size: int,
                        intermediate_size_per_partition: int,
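
The constructor changes above give FlashInfer priority over the other FP8 MoE paths on SM100 (Blackwell) hardware. Below is a minimal standalone sketch, not vLLM code, of that selection order under simplifying assumptions: the helper name, its arguments, and the "default" fallback label are hypothetical, and in the real class DeepGemm and the Cutlass path are tracked as independent flags rather than a single return value.

def select_fp8_moe_backend(is_sm100: bool, flashinfer_requested: bool,
                           flashinfer_available: bool, deep_gemm_requested: bool,
                           block_quant: bool) -> str:
    # FlashInfer wins on Blackwell when VLLM_USE_FLASHINFER_MOE_FP8 is set
    # and the FlashInfer MoE kernels can be imported.
    if is_sm100 and flashinfer_requested and flashinfer_available:
        return "flashinfer"
    # DeepGemm is only considered for block-quantized weights and is now
    # skipped whenever FlashInfer was chosen above.
    if deep_gemm_requested and block_quant:
        return "deep_gemm"
    # The Cutlass block-scaled grouped GEMM path is likewise SM100-only and
    # also skipped once FlashInfer is active.
    if is_sm100 and block_quant:
        return "cutlass_block_scaled_grouped_gemm"
    return "default"

# Example: on Blackwell with FlashInfer enabled, DeepGemm/Cutlass are bypassed.
assert select_fp8_moe_backend(True, True, True, True, True) == "flashinfer"
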
@@ -934,7 +939,9 @@ def apply(
             import vllm.model_executor.layers.fused_moe.flashinfer_trtllm_moe  # noqa: E501, F401
             assert (renormalize and use_grouped_topk
                     and custom_routing_function is None)
-            result = torch.ops.vllm.flashinfer_fused_moe_blockscale_fp8(
+            e_score_correction_bias = (e_score_correction_bias.to(
+                x.dtype) if e_score_correction_bias is not None else None)
+            return torch.ops.vllm.flashinfer_fused_moe_blockscale_fp8(
                 routing_logits=router_logits.to(torch.float32),
                 routing_bias=e_score_correction_bias,
                 x=x,
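
The apply() change above casts the optional e_score_correction_bias to the activation dtype before handing it to the FlashInfer block-scale FP8 op, and returns the op's result directly. A small self-contained sketch of that cast-if-present pattern in plain PyTorch (the helper name is illustrative only):

import torch
from typing import Optional

def align_routing_bias(bias: Optional[torch.Tensor],
                       activations: torch.Tensor) -> Optional[torch.Tensor]:
    # Cast only when a bias is supplied; None passes through unchanged.
    return bias.to(activations.dtype) if bias is not None else None

x = torch.randn(4, 8, dtype=torch.bfloat16)
bias = torch.zeros(8, dtype=torch.float32)
assert align_routing_bias(bias, x).dtype == torch.bfloat16
assert align_routing_bias(None, x) is None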