 from vllm.model_executor.layers.fused_moe.fused_marlin_moe import fused_marlin_moe
 from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
     is_rocm_aiter_moe_enabled,
+    use_mxfp4_aiter_moe,
 )
 from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
     prepare_moe_fp8_layer_for_marlin,
@@ -472,22 +473,22 @@ def __init__(
                 "not implemented. Please open an issue."
             )
 
-        if not current_platform.supports_mx():
-            self.emulate = True
+        self.emulate = not current_platform.supports_mx() or not (
+            use_mxfp4_aiter_moe() and self.ocp_mx_scheme == "w_mxfp4_a_mxfp4"
+        )
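+        # Native (non-emulated) execution requires both hardware MX support
+        # and the ROCm aiter MXFP4 MoE path with the w_mxfp4_a_mxfp4 scheme;
+        # any other combination falls back to QDQ emulation.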
+        if self.emulate:
             logger.warning_once(
-                "The current platform does not support native MXFP4/MXFP6 "
+                f"The current mode (supports_mx={current_platform.supports_mx()}, "
+                f"use_mxfp4_aiter_moe={use_mxfp4_aiter_moe()}, "
+                f"ocp_mx_scheme={self.ocp_mx_scheme}) "
+                "does not support native MXFP4/MXFP6 "
                 "computation. Simulated weight dequantization and activation "
                 "QDQ (quantize and dequantize) will be used, with the linear "
                 "layers computed in high precision."
             )
         else:
-            self.emulate = True
             logger.warning_once(
-                "The current platform supports native MXFP4/MXFP6 "
-                "computation, but kernels are not yet integrated in vLLM. "
-                "Simulated weight dequantization and activation "
-                "QDQ (quantize and dequantize) will be used, with the linear "
-                "layers computed in high precision."
+                "The current mode supports native MoE MXFP4 computation."
             )
 
     def get_packed_dim(self, dim: int, quant_dtype: str):
@@ -568,6 +569,24 @@ def create_weights(
         layer.register_parameter("w13_weight_scale", w13_weight_scale)
         layer.register_parameter("w2_weight_scale", w2_weight_scale)
 
+    def process_weights_after_loading(self, layer):
+        if self.emulate:
+            return
+
+        from aiter.utility.fp4_utils import e8m0_shuffle
+
+        # Pre-shuffle weight scales
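+        # e8m0_shuffle rearranges the e8m0 block scales into the layout the
+        # aiter fused-MoE kernels expect; the expert and row dims are
+        # flattened into 2D for the shuffle, then the shape is restored.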
+        s0, s1, _ = layer.w13_weight_scale.shape
+        w13_weight_scale = layer.w13_weight_scale.view(s0 * s1, -1)
+        w13_weight_scale = e8m0_shuffle(w13_weight_scale)
+        layer.w13_weight_scale.data = w13_weight_scale.view(s0, s1, -1)
+
+        s0, s1, _ = layer.w2_weight_scale.shape
+        w2_weight_scale = layer.w2_weight_scale.view(s0 * s1, -1)
+        w2_weight_scale = e8m0_shuffle(w2_weight_scale)
+        layer.w2_weight_scale.data = w2_weight_scale.view(s0, s1, -1)
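+        # Replacing .data above drops the old scale tensors; empty_cache
+        # returns the now-unused cached blocks to the GPU.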
+        torch.cuda.empty_cache()
+
     def get_fused_moe_quant_config(
         self, layer: torch.nn.Module
     ) -> FusedMoEQuantConfig | None:
@@ -611,8 +630,6 @@ def apply(
                 "EPLB not supported for `QuarkOCP_MX_MoEMethod` yet."
             )
 
-        from vllm.model_executor.layers.fused_moe import fused_experts
-
         topk_weights, topk_ids, _ = FusedMoE.select_experts(
             hidden_states=x,
             router_logits=router_logits,
@@ -628,17 +645,44 @@ def apply(
             indices_type=self.topk_indices_dtype,
         )
 
-        out = fused_experts(
-            x,
-            layer.w13_weight,
-            layer.w2_weight,
-            topk_weights=topk_weights,
-            topk_ids=topk_ids,
-            inplace=True,
-            activation=activation,
-            global_num_experts=global_num_experts,
-            apply_router_weight_on_input=apply_router_weight_on_input,
-            expert_map=expert_map,
-            quant_config=self.moe_quant_config,
-        )
+        if not self.emulate:
+            from aiter import ActivationType, QuantType
+            from aiter.fused_moe import fused_moe
+
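+            # Map vLLM's string activation names onto aiter's ActivationType
+            # enum; the assert below makes the supported set ("no", "silu",
+            # "gelu") explicit.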
+            aiter_acts = {
+                ActivationType.No.name.lower(): ActivationType.No,
+                ActivationType.Silu.name.lower(): ActivationType.Silu,
+                ActivationType.Gelu.name.lower(): ActivationType.Gelu,
+            }
+            assert activation in aiter_acts, (
+                f"Aiter CK fp4 MoE doesn't support activation {activation}"
+            )
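+            # QuantType.per_1x32: one e8m0 scale per 1x32 block of elements,
+            # matching the MXFP4 group size of 32.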
+            out = fused_moe(
+                x,
+                layer.w13_weight,
+                layer.w2_weight,
+                topk_weights,
+                topk_ids,
+                quant_type=QuantType.per_1x32,
+                w1_scale=layer.w13_weight_scale,
+                w2_scale=layer.w2_weight_scale,
+                activation=aiter_acts[activation],
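+                # Assumption: doweight_stage1=False applies the topk router
+                # weights in the second GEMM stage rather than the first.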
+                doweight_stage1=False,
+            )
+        else:
+            from vllm.model_executor.layers.fused_moe import fused_experts
+
+            out = fused_experts(
+                x,
+                layer.w13_weight,
+                layer.w2_weight,
+                topk_weights=topk_weights,
+                topk_ids=topk_ids,
+                inplace=True,
+                activation=activation,
+                global_num_experts=global_num_experts,
+                apply_router_weight_on_input=apply_router_weight_on_input,
+                expert_map=expert_map,
+                quant_config=self.moe_quant_config,
+            )
         return out