Commit 5efab77

Merge pull request #62 from vkuzo/20250926_mxfp4_moe_quant
enable mxfp4 quant in vllm
2 parents: c011a0f + a1ffe84

1 file changed (+39 -9 lines)

torchao_hf_vllm/torchao_hf_script.py

@@ -30,6 +30,7 @@
     Int8DynamicActivationInt4WeightConfig,
     CutlassInt4PackedLayout,
 )
+from torchao.quantization import ModuleFqnToConfig
 from torchao.prototype.mx_formats.inference_workflow import MXFPInferenceConfig
 from torchao.prototype.mx_formats import MXGemmKernelChoice
 from jsonargparse import CLI, Namespace
@@ -67,7 +68,6 @@ def get_quantization_config(args):
         case "fp8":
             single_config = Float8DynamicActivationFloat8WeightConfig(granularity=gran)
             if args.experts_only_qwen_1_5_moe_a_2_7b:
-                from torchao.quantization import ModuleFqnToConfig
                 expert_fqn_to_config = {}
                 # TODO(future PR): this is annoying, I should be able to use a regex here
                 for layer_idx in range(24):
@@ -101,14 +101,39 @@ def get_quantization_config(args):
         case "mxfp8":
             return TorchAoConfig(MXFPInferenceConfig())
         case "mxfp4":
-            return TorchAoConfig(
-                MXFPInferenceConfig(
-                    activation_dtype=torch.float4_e2m1fn_x2,
-                    weight_dtype=torch.float4_e2m1fn_x2,
-                    block_size=32,
-                    gemm_kernel_choice=MXGemmKernelChoice.CUTLASS,
-                )
+            single_config = MXFPInferenceConfig(
+                activation_dtype=torch.float4_e2m1fn_x2,
+                weight_dtype=torch.float4_e2m1fn_x2,
+                block_size=32,
+                # gemm_kernel_choice=MXGemmKernelChoice.CUTLASS,
+                gemm_kernel_choice=MXGemmKernelChoice.EMULATED,
             )
+            if args.experts_only_qwen_1_5_moe_a_2_7b:
+                expert_fqn_to_config = {}
+                # TODO(future PR): this is annoying, I should be able to use a regex here
+                for layer_idx in range(24):
+                    for expert_idx in range(60):
+                        expert_fqn_to_config[f"model.layers.{layer_idx}.mlp.experts.{expert_idx}.gate_proj"] = single_config
+                        expert_fqn_to_config[f"model.layers.{layer_idx}.mlp.experts.{expert_idx}.up_proj"] = single_config
+                        expert_fqn_to_config[f"model.layers.{layer_idx}.mlp.experts.{expert_idx}.down_proj"] = single_config
+                module_fqn_to_config = ModuleFqnToConfig({
+                    "_default": None,
+                    **expert_fqn_to_config,
+                })
+                return TorchAoConfig(
+                    quant_type=module_fqn_to_config,
+                )
+            else:
+                modules_to_not_convert = []
+                if args.skip_gate_qwen_1_5_moe_a_2_7b:
+                    for layer_idx in range(24):
+                        modules_to_not_convert.append(f"model.layers.{layer_idx}.mlp.gate")
+                        modules_to_not_convert.append(f"model.layers.{layer_idx}.mlp.shared_expert_gate")
+                    modules_to_not_convert.append(f"lm_head")
+                return TorchAoConfig(
+                    single_config,
+                    modules_to_not_convert=modules_to_not_convert,
+                )
         case _:
             raise ValueError(f"Unsupported quantization type: {args.quant_type}")
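Note: for Qwen1.5-MoE-A2.7B the experts-only branch above expands to a mapping with 24 × 60 × 3 = 4320 entries. The sketch below is not part of the diff; it only restates what the committed loops build, using the same imports and config as above.

```python
import torch
from torchao.prototype.mx_formats import MXGemmKernelChoice
from torchao.prototype.mx_formats.inference_workflow import MXFPInferenceConfig
from torchao.quantization import ModuleFqnToConfig

# Same per-module config as in the commit: emulated mxfp4 gemm, block size 32.
single_config = MXFPInferenceConfig(
    activation_dtype=torch.float4_e2m1fn_x2,
    weight_dtype=torch.float4_e2m1fn_x2,
    block_size=32,
    gemm_kernel_choice=MXGemmKernelChoice.EMULATED,
)

# Equivalent to the three nested loops in the commit: every expert projection
# in every layer is mapped to the mxfp4 config.
expert_fqn_to_config = {
    f"model.layers.{layer_idx}.mlp.experts.{expert_idx}.{proj}": single_config
    for layer_idx in range(24)
    for expert_idx in range(60)
    for proj in ("gate_proj", "up_proj", "down_proj")
}

# "_default": None leaves every module not listed (attention, routers, lm_head)
# unquantized.
module_fqn_to_config = ModuleFqnToConfig({"_default": None, **expert_fqn_to_config})
```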

@@ -165,6 +190,7 @@ def main(
     bench_tokens: int = 100,
     device_map: str = "cuda",
     experts_only_qwen_1_5_moe_a_2_7b: bool = False,
+    skip_gate_qwen_1_5_moe_a_2_7b: bool = False,
     save_model_to_disk: bool = True,
 ):
     """
@@ -182,6 +208,7 @@ def main(
         bench_tokens: Number of tokens to generate for benchmarking
         device_map: Device mapping strategy
         experts_only_qwen_1_5_moe_a_2_7b: if True, quantizes experts only for Qwen1.5-MoE-A2.7B model
+        skip_gate_qwen_1_5_moe_a_2_7b: if True, skips gate quantization for Qwen1.5-MoE-A2.7B model
         save_model_to_disk: if True, saves quantized model to local disk
     """
     # Set seed before creating the model
@@ -206,11 +233,14 @@ def main(
         device_map=device_map,
         experts_only_qwen_1_5_moe_a_2_7b=experts_only_qwen_1_5_moe_a_2_7b,
         save_model_to_disk=save_model_to_disk,
+        skip_gate_qwen_1_5_moe_a_2_7b=skip_gate_qwen_1_5_moe_a_2_7b,
     )
     print(f"{args=}")

     if args.experts_only_qwen_1_5_moe_a_2_7b:
-        assert args.quant_type == "fp8", "unsupported"
+        assert args.quant_type in ("fp8", "mxfp4"), "unsupported"
+
+    assert not args.skip_gate_qwen_1_5_moe_a_2_7b and args.experts_only_qwen_1_5_moe_a_2_7b, "unsupported"

     # Create output directory
     output_dir = Path(args.output_dir)
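Per the PR title, the end goal is serving the resulting checkpoint with vLLM. A minimal sketch of that step, assuming the script saved the quantized model to a local directory (the path below is hypothetical) and that the installed vLLM build picks up the torchao quantization config stored with the checkpoint:

```python
from vllm import LLM, SamplingParams

# Hypothetical output path produced by the quantization script above.
llm = LLM(model="./qwen-1.5-moe-a2.7b-mxfp4")

outputs = llm.generate(
    ["The capital of France is"],
    SamplingParams(max_tokens=32),
)
print(outputs[0].outputs[0].text)
```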
