|
     CutlassInt4PackedLayout,
 )
 from torchao.quantization import ModuleFqnToConfig
-from torchao.prototype.mx_formats.inference_workflow import MXFPInferenceConfig
+from torchao.prototype.mx_formats.inference_workflow import (
+    MXFPInferenceConfig,
+    NVFP4InferenceConfig,
+    NVFP4MMConfig,
+)
 from torchao.prototype.mx_formats import MXGemmKernelChoice
 from jsonargparse import CLI, Namespace
 from rich import print
@@ -134,6 +138,38 @@ def get_quantization_config(args): |
                     single_config,
                     modules_to_not_convert=modules_to_not_convert,
                 )
+        case "nvfp4":
+            single_config = NVFP4InferenceConfig(
+                mm_config=NVFP4MMConfig.WEIGHT_ONLY,
+                use_triton_kernel=False,
+                use_dynamic_per_tensor_scale=False,
+            )
+            if args.experts_only_qwen_1_5_moe_a_2_7b:
+                expert_fqn_to_config = {}
+                # TODO(future PR): this is annoying, I should be able to use a regex here
+                for layer_idx in range(24):
+                    for expert_idx in range(60):
+                        expert_fqn_to_config[f"model.layers.{layer_idx}.mlp.experts.{expert_idx}.gate_proj"] = single_config
+                        expert_fqn_to_config[f"model.layers.{layer_idx}.mlp.experts.{expert_idx}.up_proj"] = single_config
+                        expert_fqn_to_config[f"model.layers.{layer_idx}.mlp.experts.{expert_idx}.down_proj"] = single_config
+                module_fqn_to_config = ModuleFqnToConfig({
+                    "_default": None,
+                    **expert_fqn_to_config,
+                })
+                return TorchAoConfig(
+                    quant_type=module_fqn_to_config,
+                )
+            else:
+                modules_to_not_convert = []
+                if args.skip_gate_qwen_1_5_moe_a_2_7b:
+                    for layer_idx in range(24):
+                        modules_to_not_convert.append(f"model.layers.{layer_idx}.mlp.gate")
+                        modules_to_not_convert.append(f"model.layers.{layer_idx}.mlp.shared_expert_gate")
+                modules_to_not_convert.append("lm_head")
+                return TorchAoConfig(
+                    single_config,
+                    modules_to_not_convert=modules_to_not_convert,
+                )
         case _:
             raise ValueError(f"Unsupported quantization type: {args.quant_type}")
 
|
@@ -182,6 +218,7 @@ def main( |
         "A8W4",
         "fp8",
         "mxfp4",
+        "nvfp4",
     ] = "fp8",
     granularity: Literal["per_row", "per_tensor"] = "per_row",
     min_sqnr: Optional[float] = None,
@@ -238,9 +275,9 @@ def main( |
     print(f"{args=}")
 
     if args.experts_only_qwen_1_5_moe_a_2_7b:
-        assert args.quant_type in ("fp8", "mxfp4"), "unsupported"
+        assert args.quant_type in ("fp8", "mxfp4", "nvfp4"), "unsupported"
 
-    assert not args.skip_gate_qwen_1_5_moe_a_2_7b and args.experts_only_qwen_1_5_moe_a_2_7b, "unsupported"
+    assert not (args.skip_gate_qwen_1_5_moe_a_2_7b and args.experts_only_qwen_1_5_moe_a_2_7b), "unsupported"
 
     # Create output directory
     output_dir = Path(args.output_dir)
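
For reference, below is a minimal sketch of consuming an NVFP4 weight-only config like the one this diff builds for the non-experts-only path, through Hugging Face transformers' TorchAoConfig integration. The checkpoint id, dtype, and device placement are assumptions for illustration and do not come from this PR.

import torch
from transformers import AutoModelForCausalLM, TorchAoConfig
from torchao.prototype.mx_formats.inference_workflow import (
    NVFP4InferenceConfig,
    NVFP4MMConfig,
)

# NVFP4 weight-only config, mirroring the "nvfp4" case added above
quant_config = TorchAoConfig(
    NVFP4InferenceConfig(
        mm_config=NVFP4MMConfig.WEIGHT_ONLY,
        use_triton_kernel=False,
        use_dynamic_per_tensor_scale=False,
    ),
    # keep the LM head in high precision, as in the non-experts-only branch
    modules_to_not_convert=["lm_head"],
)

# model id, dtype, and device placement are assumptions, not taken from the diff
model = AutoModelForCausalLM.from_pretrained(
    "Qwen/Qwen1.5-MoE-A2.7B",
    torch_dtype=torch.bfloat16,
    device_map="cuda",
    quantization_config=quant_config,
)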
|