
Commit a9c4b86

[main][bugfix] bugfix for qwen3 moe quantization (#4599)
### What this PR does / why we need it?
Fix the issue where the Qwen3 MoE service cannot be started after upgrading the vLLM version.

Error info: `AttributeError: 'AscendFusedMoE' object has no attribute 'use_dp_chunking'`

### Does this PR introduce _any_ user-facing change?
No.

- vLLM version: v0.11.2

Signed-off-by: Wang Kunpeng <1289706727@qq.com>
1 parent 12ca99c commit a9c4b86
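
For reference, a minimal reproduction sketch of the failure this commit fixes. The model name and arguments follow the test added below; the surrounding scaffolding is assumed, not part of the commit:

```python
# Hedged sketch: before this fix, loading a quantized Qwen3 MoE checkpoint on
# vLLM v0.11.2 failed during engine construction with
#   AttributeError: 'AscendFusedMoE' object has no attribute 'use_dp_chunking'
# because AscendFusedMoEMethod skipped the upstream base-class initializer.
from vllm import LLM

llm = LLM(
    model="vllm-ascend/Qwen3-30B-A3B-W8A8",  # W8A8 checkpoint used by the new test
    quantization="ascend",                   # value the new test passes
    enable_expert_parallel=True,
    trust_remote_code=True,
)
```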

File tree

4 files changed: +36 −10 lines changed

examples/offline_data_parallel.py

Lines changed: 9 additions & 0 deletions
@@ -111,6 +111,10 @@ def parse_args():
     parser.add_argument("--enable-expert-parallel",
                         action="store_true",
                         help="Enable expert parallel, used in MOE models.")
+    parser.add_argument("--quantization",
+                        type=str,
+                        default="",
+                        help="Use quantization models")
     return parser.parse_args()


@@ -134,6 +138,7 @@ def main(
     enable_expert_parallel,
     enforce_eager,
     trust_remote_code,
+    quantization,
 ):
     # DP only support on V1 engine
     os.environ["VLLM_DP_RANK"] = str(global_dp_rank)
@@ -185,6 +190,7 @@ def start(rank):
         enforce_eager=enforce_eager,
         enable_expert_parallel=enable_expert_parallel,
         trust_remote_code=trust_remote_code,
+        quantization=quantization,
     )
     outputs = llm.generate(prompts, sampling_params)
     # Print the outputs.
@@ -220,6 +226,8 @@ def start(rank):
     assert dp_size % node_size == 0, "dp_size should be divisible by node_size"
     dp_per_node = dp_size // node_size

+    quantization = args.quantization if args.quantization else None
+
     from multiprocessing import Process

     procs = []
@@ -238,6 +246,7 @@ def start(rank):
             args.enable_expert_parallel,
             args.enforce_eager,
             args.trust_remote_code,
+            quantization,
         ),
     )
     proc.start()
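
Taken together, these hunks thread the new --quantization flag from the CLI through the worker processes into vllm.LLM. A condensed sketch of that path (argparse wiring taken from the diff; surrounding scaffolding assumed):

```python
# Condensed sketch of the new flag's path through the example script.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--quantization",
                    type=str,
                    default="",
                    help="Use quantization models")
args = parser.parse_args()

# An empty string means "no quantization"; vllm.LLM expects None in that
# case, hence the explicit conversion before the worker processes are spawned.
quantization = args.quantization if args.quantization else None
```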

tests/e2e/multicard/test_data_parallel.py

Lines changed: 9 additions & 2 deletions
@@ -27,13 +27,17 @@

 import pytest

-MODELS = ["Qwen/Qwen3-0.6B", "Qwen/Qwen3-30B-A3B"]
+MODELS = [
+    "Qwen/Qwen3-0.6B", "Qwen/Qwen3-30B-A3B", "vllm-ascend/Qwen3-30B-A3B-W8A8"
+]


 @pytest.mark.parametrize("model", MODELS)
 @pytest.mark.parametrize("max_tokens", [32])
 @patch.dict(os.environ, {"ASCEND_RT_VISIBLE_DEVICES": "0,1"})
 def test_data_parallel_inference(model, max_tokens):
+    moe_models = ["Qwen/Qwen3-30B-A3B", "vllm-ascend/Qwen3-30B-A3B-W8A8"]
+    quantization_models = ["vllm-ascend/Qwen3-30B-A3B-W8A8"]
     script = "examples/offline_data_parallel.py"

     env = os.environ.copy()
@@ -54,8 +58,11 @@ def test_data_parallel_inference(model, max_tokens):
         "--trust-remote-code",
     ]

-    if model == "Qwen/Qwen3-30B-A3B":
+    if model in moe_models:
         cmd.append("--enable-expert-parallel")
+    if model in quantization_models:
+        cmd.append("--quantization")
+        cmd.append("ascend")

     print(f"Running subprocess: {' '.join(cmd)}")
     proc = subprocess.run(cmd,
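
The branch logic added above, isolated for clarity (model lists and flag values copied from the diff; the print is illustrative only):

```python
# MoE models get expert parallelism; the W8A8 checkpoint additionally gets
# Ascend quantization.
model = "vllm-ascend/Qwen3-30B-A3B-W8A8"
moe_models = ["Qwen/Qwen3-30B-A3B", "vllm-ascend/Qwen3-30B-A3B-W8A8"]
quantization_models = ["vllm-ascend/Qwen3-30B-A3B-W8A8"]

extra_flags = []
if model in moe_models:
    extra_flags.append("--enable-expert-parallel")
if model in quantization_models:
    extra_flags += ["--quantization", "ascend"]

print(extra_flags)  # ['--enable-expert-parallel', '--quantization', 'ascend']
```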

vllm_ascend/quantization/quant_config.py

Lines changed: 4 additions & 5 deletions
@@ -408,11 +408,10 @@ class AscendFusedMoEMethod(FusedMoEMethodBase):
         quant_config: The Ascend quantization config.
     """

-    def __init__(self,
-                 quant_config: AscendQuantConfig,
-                 prefix: str,
-                 packed_modules_mapping: Dict[str, Any],
-                 layer: torch.nn.Module = None):
+    def __init__(self, quant_config: AscendQuantConfig, prefix: str,
+                 packed_modules_mapping: Dict[str,
+                                              Any], layer: torch.nn.Module):
+        super().__init__(layer.moe_config)
     self.quant_method = get_quant_method(quant_config.quant_description,
                                          prefix,
                                          "moe",

vllm_ascend/torchair/ops/torchair_fused_moe.py

Lines changed: 14 additions & 3 deletions
@@ -16,7 +16,7 @@
 # Adapted from vllm/tests/kernels/test_moe.py

 import os
-from typing import Any, Callable, Optional, Tuple, Union
+from typing import Any, Callable, Dict, Optional, Tuple, Union

 import torch
 import torch.distributed as dist
@@ -45,7 +45,9 @@
 from vllm_ascend.distributed.parallel_state import get_mc2_group
 from vllm_ascend.eplb.core.eplb_utils import determine_default_log2phy_map
 from vllm_ascend.ops.expert_load_balancer import ExpertLoadBalancer
-from vllm_ascend.quantization.quant_config import AscendFusedMoEMethod
+from vllm_ascend.quantization.quant_config import (AscendFusedMoEMethod,
+                                                   AscendQuantConfig)
+from vllm_ascend.quantization.utils import get_quant_method
 from vllm_ascend.torchair.ops.sequence_parallel import MetadataForPadding
 from vllm_ascend.torchair.utils import (get_all_reduce_merge_state,
                                         get_rm_router_logits_state,
@@ -936,6 +938,15 @@ def apply(
             ep_group=get_ep_group())


+class TorchairAscendFusedMoEMethod(AscendFusedMoEMethod):
+
+    def __init__(self, quant_config: AscendQuantConfig, prefix: str,
+                 packed_modules_mapping: Dict[str, Any]):
+        self.quant_method = get_quant_method(quant_config.quant_description,
+                                             prefix, "moe",
+                                             packed_modules_mapping)
+
+
 class TorchairAscendFusedMoE(FusedMoE):

     # The moe_counter parameter is required during the initialization of EPLB
@@ -1115,7 +1126,7 @@ def __init__(
             self.quant_method = TorchairAscendUnquantizedFusedMoEMethod(
                 self.moe)
         else:
-            self.quant_method = AscendFusedMoEMethod(
+            self.quant_method = TorchairAscendFusedMoEMethod(
                 quant_config, prefix, quant_config.packed_modules_mapping)

         assert self.quant_method is not None
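
The torchair path constructs its quantized MoE method with only three arguments (no layer), so it gets a dedicated subclass that keeps the old behavior and deliberately skips the base-class moe_config initialization. A hedged restatement of the selection in the final hunk (the helper name is hypothetical; the class and attribute names come from the diff):

```python
# Quantized layers now use the torchair-specific subclass, whose constructor
# takes no `layer` argument and therefore never touches layer.moe_config.
def _select_quant_method(layer, quant_config, prefix):
    # `layer` is the TorchairAscendFusedMoE instance being initialized.
    if quant_config is None:
        return TorchairAscendUnquantizedFusedMoEMethod(layer.moe)
    return TorchairAscendFusedMoEMethod(
        quant_config, prefix, quant_config.packed_modules_mapping)
```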
