
Commit 015c7a6

maleksan85 (Aleksandr Malyshev) authored and committed

[ROCM] MoE fp4 CK kernel (vllm-project#26545)

Signed-off-by: Aleksandr Malyshev <maleksan@amd.com>
Co-authored-by: Aleksandr Malyshev <maleksan@amd.com>
Signed-off-by: xuebwang-amd <xuebwang@amd.com>
1 parent b8692df commit 015c7a6

File tree

2 files changed: +73, -24 lines


vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py

Lines changed: 5 additions & 0 deletions

@@ -46,6 +46,11 @@ def is_rocm_aiter_moe_enabled() -> bool:
     )
 
 
+@cache
+def use_mxfp4_aiter_moe() -> bool:
+    return current_platform.is_rocm() and envs.VLLM_ROCM_USE_AITER
+
+
 @cache
 def is_rocm_aiter_fusion_shared_expert_enabled() -> bool:
     return (
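Note: the new helper is a cached feature gate combining the platform check with the VLLM_ROCM_USE_AITER toggle. A minimal standalone sketch of the same pattern, where _is_rocm() and the raw os.environ read are hypothetical stand-ins for current_platform.is_rocm() and envs.VLLM_ROCM_USE_AITER:

import os
from functools import cache

def _is_rocm() -> bool:
    # Hypothetical stand-in; the real helper asks vLLM's current_platform.is_rocm().
    return os.environ.get("ROCM_PATH") is not None

@cache
def use_mxfp4_aiter_moe() -> bool:
    # @cache means the gate is evaluated once per process, as in the diff.
    return _is_rocm() and os.environ.get("VLLM_ROCM_USE_AITER", "0") == "1"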

vllm/model_executor/layers/quantization/quark/quark_moe.py

Lines changed: 68 additions & 24 deletions

@@ -23,6 +23,7 @@
 from vllm.model_executor.layers.fused_moe.fused_marlin_moe import fused_marlin_moe
 from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
     is_rocm_aiter_moe_enabled,
+    use_mxfp4_aiter_moe,
 )
 from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
     prepare_moe_fp8_layer_for_marlin,
@@ -472,22 +473,22 @@ def __init__(
                 "not implemented. Please open an issue."
             )
 
-        if not current_platform.supports_mx():
-            self.emulate = True
+        self.emulate = not current_platform.supports_mx() or not (
+            use_mxfp4_aiter_moe() and self.ocp_mx_scheme == "w_mxfp4_a_mxfp4"
+        )
+        if self.emulate:
             logger.warning_once(
-                "The current platform does not support native MXFP4/MXFP6 "
+                f"The current mode (supports_mx={current_platform.supports_mx()}, "
+                f"use_mxfp4_aiter_moe={use_mxfp4_aiter_moe()}, "
+                f"ocp_mx_scheme={self.ocp_mx_scheme}) "
+                "does not support native MXFP4/MXFP6 "
                 "computation. Simulated weight dequantization and activation "
                 "QDQ (quantize and dequantize) will be used, with the linear "
                 "layers computed in high precision."
             )
         else:
-            self.emulate = True
             logger.warning_once(
-                "The current platform supports native MXFP4/MXFP6 "
-                "computation, but kernels are not yet integrated in vLLM. "
-                "Simulated weight dequantization and activation "
-                "QDQ (quantize and dequantize) will be used, with the linear "
-                "layers computed in high precision."
+                "The current mode supports native MoE MXFP4 computation"
             )
 
     def get_packed_dim(self, dim: int, quant_dtype: str):
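The rewritten __init__ collapses the emulation decision into a single expression: the native CK path is taken only when the platform supports MX formats, the AITER MoE gate is on, and the scheme is exactly w_mxfp4_a_mxfp4; otherwise the existing QDQ emulation is kept. A pure-function sketch of that predicate (the function name is illustrative):

def should_emulate(supports_mx: bool, aiter_moe_on: bool, ocp_mx_scheme: str) -> bool:
    # Mirrors the diff: emulate unless every native-path condition holds.
    return not supports_mx or not (aiter_moe_on and ocp_mx_scheme == "w_mxfp4_a_mxfp4")

assert should_emulate(False, True, "w_mxfp4_a_mxfp4")     # no native MX support
assert should_emulate(True, False, "w_mxfp4_a_mxfp4")     # AITER MoE gate off
assert should_emulate(True, True, "w_mxfp6_a_mxfp6")      # scheme not w_mxfp4_a_mxfp4
assert not should_emulate(True, True, "w_mxfp4_a_mxfp4")  # native CK fp4 path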
@@ -568,6 +569,24 @@ def create_weights(
         layer.register_parameter("w13_weight_scale", w13_weight_scale)
         layer.register_parameter("w2_weight_scale", w2_weight_scale)
 
+    def process_weights_after_loading(self, layer):
+        if self.emulate:
+            return
+
+        from aiter.utility.fp4_utils import e8m0_shuffle
+
+        # Pre-shuffle weight scales
+        s0, s1, _ = layer.w13_weight_scale.shape
+        w13_weight_scale = layer.w13_weight_scale.view(s0 * s1, -1)
+        w13_weight_scale = e8m0_shuffle(w13_weight_scale)
+        layer.w13_weight_scale.data = w13_weight_scale.view(s0, s1, -1)
+
+        s0, s1, _ = layer.w2_weight_scale.shape
+        w2_weight_scale = layer.w2_weight_scale.view(s0 * s1, -1)
+        w2_weight_scale = e8m0_shuffle(w2_weight_scale)
+        layer.w2_weight_scale.data = w2_weight_scale.view(s0, s1, -1)
+        torch.cuda.empty_cache()
+
     def get_fused_moe_quant_config(
         self, layer: torch.nn.Module
     ) -> FusedMoEQuantConfig | None:
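process_weights_after_loading flattens each per-expert scale tensor to 2-D, runs aiter's e8m0_shuffle over it, and views the result back to the original 3-D shape. A shape-only sketch of that round trip, with a hypothetical identity stand-in for the shuffle (aiter is assumed unavailable here; the shapes are illustrative):

import torch

def e8m0_shuffle_stub(t: torch.Tensor) -> torch.Tensor:
    # Hypothetical stand-in: the real aiter.utility.fp4_utils.e8m0_shuffle
    # permutes the e8m0 scale layout into the order the CK fp4 kernel expects.
    return t.clone()

scale = torch.randn(8, 16, 32)  # (experts, rows, 1x32-group scales)
s0, s1, _ = scale.shape
shuffled = e8m0_shuffle_stub(scale.view(s0 * s1, -1))  # flatten to 2-D for the shuffle
scale = shuffled.view(s0, s1, -1)  # view back; the real shuffle reorders contents only
assert scale.shape == (8, 16, 32)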
@@ -611,8 +630,6 @@ def apply(
                 "EPLB not supported for `QuarkOCP_MX_MoEMethod` yet."
             )
 
-        from vllm.model_executor.layers.fused_moe import fused_experts
-
         topk_weights, topk_ids, _ = FusedMoE.select_experts(
             hidden_states=x,
             router_logits=router_logits,
@@ -628,17 +645,44 @@ def apply(
             indices_type=self.topk_indices_dtype,
         )
 
-        out = fused_experts(
-            x,
-            layer.w13_weight,
-            layer.w2_weight,
-            topk_weights=topk_weights,
-            topk_ids=topk_ids,
-            inplace=True,
-            activation=activation,
-            global_num_experts=global_num_experts,
-            apply_router_weight_on_input=apply_router_weight_on_input,
-            expert_map=expert_map,
-            quant_config=self.moe_quant_config,
-        )
+        if not self.emulate:
+            from aiter import ActivationType, QuantType
+            from aiter.fused_moe import fused_moe
+
+            aiter_acts = {
+                ActivationType.No.name.lower(): ActivationType.No,
+                ActivationType.Silu.name.lower(): ActivationType.Silu,
+                ActivationType.Gelu.name.lower(): ActivationType.Gelu,
+            }
+            assert activation in aiter_acts, (
+                f"Aiter CK fp4 MoE doesn't support activation {activation}"
+            )
+            out = fused_moe(
+                x,
+                layer.w13_weight,
+                layer.w2_weight,
+                topk_weights,
+                topk_ids,
+                quant_type=QuantType.per_1x32,
+                w1_scale=layer.w13_weight_scale,
+                w2_scale=layer.w2_weight_scale,
+                activation=aiter_acts[activation],
+                doweight_stage1=False,
+            )
+        else:
+            from vllm.model_executor.layers.fused_moe import fused_experts
+
+            out = fused_experts(
+                x,
+                layer.w13_weight,
+                layer.w2_weight,
+                topk_weights=topk_weights,
+                topk_ids=topk_ids,
+                inplace=True,
+                activation=activation,
+                global_num_experts=global_num_experts,
+                apply_router_weight_on_input=apply_router_weight_on_input,
+                expert_map=expert_map,
+                quant_config=self.moe_quant_config,
+            )
         return out
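The native branch maps vLLM's string activation names onto aiter's ActivationType enum by lower-casing the enum member names, then asserts the requested activation is supported before calling the CK kernel. A self-contained sketch of that mapping, using a stand-in enum since the real one lives in aiter:

from enum import Enum

class ActivationType(Enum):
    # Stand-in for aiter.ActivationType; member names match the diff's usage.
    No = 0
    Silu = 1
    Gelu = 2

# Lower-cased member names become the lookup keys, exactly as the diff builds them.
aiter_acts = {member.name.lower(): member for member in ActivationType}

activation = "silu"  # vLLM passes the activation as a string
assert activation in aiter_acts, f"Aiter CK fp4 MoE doesn't support activation {activation}"
print(aiter_acts[activation])  # ActivationType.Silu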
