Commit 7cd115e

Merge pull request #63 from vkuzo/20250929_update_nvfp4

add more nvfp4 handling

2 parents 5efab77 + f6eb5dc

File tree

1 file changed: +40 −3 lines changed


torchao_hf_vllm/torchao_hf_script.py

Lines changed: 40 additions & 3 deletions
@@ -31,7 +31,11 @@
     CutlassInt4PackedLayout,
 )
 from torchao.quantization import ModuleFqnToConfig
-from torchao.prototype.mx_formats.inference_workflow import MXFPInferenceConfig
+from torchao.prototype.mx_formats.inference_workflow import (
+    MXFPInferenceConfig,
+    NVFP4InferenceConfig,
+    NVFP4MMConfig,
+)
 from torchao.prototype.mx_formats import MXGemmKernelChoice
 from jsonargparse import CLI, Namespace
 from rich import print
@@ -134,6 +138,38 @@ def get_quantization_config(args):
             single_config,
             modules_to_not_convert=modules_to_not_convert,
         )
+        case "nvfp4":
+            single_config = NVFP4InferenceConfig(
+                mm_config=NVFP4MMConfig.WEIGHT_ONLY,
+                use_triton_kernel=False,
+                use_dynamic_per_tensor_scale=False,
+            )
+            if args.experts_only_qwen_1_5_moe_a_2_7b:
+                expert_fqn_to_config = {}
+                # TODO(future PR): this is annoying, I should be able to use a regex here
+                for layer_idx in range(24):
+                    for expert_idx in range(60):
+                        expert_fqn_to_config[f"model.layers.{layer_idx}.mlp.experts.{expert_idx}.gate_proj"] = single_config
+                        expert_fqn_to_config[f"model.layers.{layer_idx}.mlp.experts.{expert_idx}.up_proj"] = single_config
+                        expert_fqn_to_config[f"model.layers.{layer_idx}.mlp.experts.{expert_idx}.down_proj"] = single_config
+                module_fqn_to_config = ModuleFqnToConfig({
+                    "_default": None,
+                    **expert_fqn_to_config,
+                })
+                return TorchAoConfig(
+                    quant_type=module_fqn_to_config,
+                )
+            else:
+                modules_to_not_convert = []
+                if args.skip_gate_qwen_1_5_moe_a_2_7b:
+                    for layer_idx in range(24):
+                        modules_to_not_convert.append(f"model.layers.{layer_idx}.mlp.gate")
+                        modules_to_not_convert.append(f"model.layers.{layer_idx}.mlp.shared_expert_gate")
+                    modules_to_not_convert.append(f"lm_head")
+                return TorchAoConfig(
+                    single_config,
+                    modules_to_not_convert=modules_to_not_convert,
+                )
         case _:
             raise ValueError(f"Unsupported quantization type: {args.quant_type}")

@@ -182,6 +218,7 @@ def main(
         "A8W4",
         "fp8",
         "mxfp4",
+        "nvfp4",
     ] = "fp8",
     granularity: Literal["per_row", "per_tensor"] = "per_row",
     min_sqnr: Optional[float] = None,
@@ -238,9 +275,9 @@ def main(
     print(f"{args=}")

     if args.experts_only_qwen_1_5_moe_a_2_7b:
-        assert args.quant_type in ("fp8", "mxfp4"), "unsupported"
+        assert args.quant_type in ("fp8", "mxfp4", "nvfp4"), "unsupported"

-    assert not args.skip_gate_qwen_1_5_moe_a_2_7b and args.experts_only_qwen_1_5_moe_a_2_7b, "unsupported"
+    assert not (args.skip_gate_qwen_1_5_moe_a_2_7b and args.experts_only_qwen_1_5_moe_a_2_7b), "unsupported"

     # Create output directory
     output_dir = Path(args.output_dir)
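
For context on how the new nvfp4 branch is meant to be consumed: the TorchAoConfig it returns is a Hugging Face transformers quantization config, so the experts-only path can be exercised roughly as in the sketch below. This is a minimal illustration, not the script itself; the model id Qwen/Qwen1.5-MoE-A2.7B, the dtype/device arguments, and the save step are assumptions based on the qwen_1_5_moe_a_2_7b flags in the diff, while the config construction mirrors the added code.

```python
# Minimal sketch (not from the script): apply the nvfp4 experts-only config
# added in this commit to a Hugging Face checkpoint. Model id and save path
# are assumptions; only the config construction mirrors the diff.
from transformers import AutoModelForCausalLM, TorchAoConfig
from torchao.quantization import ModuleFqnToConfig
from torchao.prototype.mx_formats.inference_workflow import (
    NVFP4InferenceConfig,
    NVFP4MMConfig,
)

# Weight-only NVFP4, as configured in the new "nvfp4" case.
single_config = NVFP4InferenceConfig(
    mm_config=NVFP4MMConfig.WEIGHT_ONLY,
    use_triton_kernel=False,
    use_dynamic_per_tensor_scale=False,
)

# Quantize only the per-expert projections, as experts_only_qwen_1_5_moe_a_2_7b does.
expert_fqn_to_config = {
    f"model.layers.{layer_idx}.mlp.experts.{expert_idx}.{proj}": single_config
    for layer_idx in range(24)
    for expert_idx in range(60)
    for proj in ("gate_proj", "up_proj", "down_proj")
}
quant_config = TorchAoConfig(
    quant_type=ModuleFqnToConfig({"_default": None, **expert_fqn_to_config}),
)

# Assumed model id for the Qwen 1.5 MoE A2.7B checkpoint referenced by the flags.
model = AutoModelForCausalLM.from_pretrained(
    "Qwen/Qwen1.5-MoE-A2.7B",
    quantization_config=quant_config,
    torch_dtype="bfloat16",
    device_map="cuda",
)
# torchao tensor subclasses are not safetensors-serializable, hence
# safe_serialization=False (assumption about how the checkpoint is saved).
model.save_pretrained("qwen_moe_nvfp4", safe_serialization=False)
```

The "_default": None entry is what restricts quantization to the listed expert projections: any module whose FQN is not a key in the mapping falls back to the default, and a None default leaves its weights untouched.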
