@@ -30,6 +30,7 @@
     Int8DynamicActivationInt4WeightConfig,
     CutlassInt4PackedLayout,
 )
+from torchao.quantization import ModuleFqnToConfig
 from torchao.prototype.mx_formats.inference_workflow import MXFPInferenceConfig
 from torchao.prototype.mx_formats import MXGemmKernelChoice
 from jsonargparse import CLI, Namespace
@@ -67,7 +68,6 @@ def get_quantization_config(args):
         case "fp8":
             single_config = Float8DynamicActivationFloat8WeightConfig(granularity=gran)
             if args.experts_only_qwen_1_5_moe_a_2_7b:
-                from torchao.quantization import ModuleFqnToConfig
                 expert_fqn_to_config = {}
                 # TODO(future PR): this is annoying, I should be able to use a regex here
                 for layer_idx in range(24):
@@ -101,14 +101,39 @@ def get_quantization_config(args):
         case "mxfp8":
             return TorchAoConfig(MXFPInferenceConfig())
         case "mxfp4":
-            return TorchAoConfig(
-                MXFPInferenceConfig(
-                    activation_dtype=torch.float4_e2m1fn_x2,
-                    weight_dtype=torch.float4_e2m1fn_x2,
-                    block_size=32,
-                    gemm_kernel_choice=MXGemmKernelChoice.CUTLASS,
-                )
+            single_config = MXFPInferenceConfig(
+                activation_dtype=torch.float4_e2m1fn_x2,
+                weight_dtype=torch.float4_e2m1fn_x2,
+                block_size=32,
+                # gemm_kernel_choice=MXGemmKernelChoice.CUTLASS,
+                gemm_kernel_choice=MXGemmKernelChoice.EMULATED,
             )
+            if args.experts_only_qwen_1_5_moe_a_2_7b:
+                expert_fqn_to_config = {}
+                # TODO(future PR): this is annoying, I should be able to use a regex here
+                for layer_idx in range(24):
+                    for expert_idx in range(60):
+                        expert_fqn_to_config[f"model.layers.{layer_idx}.mlp.experts.{expert_idx}.gate_proj"] = single_config
+                        expert_fqn_to_config[f"model.layers.{layer_idx}.mlp.experts.{expert_idx}.up_proj"] = single_config
+                        expert_fqn_to_config[f"model.layers.{layer_idx}.mlp.experts.{expert_idx}.down_proj"] = single_config
+                module_fqn_to_config = ModuleFqnToConfig({
+                    "_default": None,
+                    **expert_fqn_to_config,
+                })
+                return TorchAoConfig(
+                    quant_type=module_fqn_to_config,
+                )
+            else:
+                modules_to_not_convert = []
+                if args.skip_gate_qwen_1_5_moe_a_2_7b:
+                    for layer_idx in range(24):
+                        modules_to_not_convert.append(f"model.layers.{layer_idx}.mlp.gate")
+                        modules_to_not_convert.append(f"model.layers.{layer_idx}.mlp.shared_expert_gate")
+                    modules_to_not_convert.append("lm_head")
+                return TorchAoConfig(
+                    single_config,
+                    modules_to_not_convert=modules_to_not_convert,
+                )
         case _:
             raise ValueError(f"Unsupported quantization type: {args.quant_type}")

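For readers unfamiliar with `ModuleFqnToConfig`: both branches above map module fully-qualified names to per-module configs, and the `"_default": None` entry leaves every unlisted module (router gates, shared expert, attention, `lm_head`) unquantized. Below is a minimal, self-contained sketch of how such a config is consumed through the Hugging Face `TorchAoConfig` integration; the model id, dtype, and `PerRow` granularity are illustrative assumptions, not taken from this diff.

```python
# Illustrative sketch (not part of this diff): quantize only the expert
# projections of Qwen1.5-MoE-A2.7B to fp8, leaving every other module alone.
import torch
from transformers import AutoModelForCausalLM, TorchAoConfig
from torchao.quantization import (
    Float8DynamicActivationFloat8WeightConfig,
    ModuleFqnToConfig,
    PerRow,
)

# One config instance shared by all expert projections (granularity assumed PerRow).
expert_config = Float8DynamicActivationFloat8WeightConfig(granularity=PerRow())

expert_fqn_to_config = {}
for layer_idx in range(24):       # 24 layers, as in the diff
    for expert_idx in range(60):  # 60 routed experts per layer, as in the diff
        for proj in ("gate_proj", "up_proj", "down_proj"):
            fqn = f"model.layers.{layer_idx}.mlp.experts.{expert_idx}.{proj}"
            expert_fqn_to_config[fqn] = expert_config

# "_default": None means any module not listed here stays unquantized.
quant_config = TorchAoConfig(
    quant_type=ModuleFqnToConfig({"_default": None, **expert_fqn_to_config}),
)

model = AutoModelForCausalLM.from_pretrained(
    "Qwen/Qwen1.5-MoE-A2.7B",  # assumed model id
    torch_dtype=torch.bfloat16,
    device_map="cuda",
    quantization_config=quant_config,
)
```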
@@ -165,6 +190,7 @@ def main(
     bench_tokens: int = 100,
     device_map: str = "cuda",
     experts_only_qwen_1_5_moe_a_2_7b: bool = False,
+    skip_gate_qwen_1_5_moe_a_2_7b: bool = False,
     save_model_to_disk: bool = True,
 ):
     """
@@ -182,6 +208,7 @@ def main(
         bench_tokens: Number of tokens to generate for benchmarking
         device_map: Device mapping strategy
         experts_only_qwen_1_5_moe_a_2_7b: if True, quantizes experts only for Qwen1.5-MoE-A2.7B model
+        skip_gate_qwen_1_5_moe_a_2_7b: if True, skips gate quantization for Qwen1.5-MoE-A2.7B model
         save_model_to_disk: if True, saves quantized model to local disk
     """
     # Set seed before creating the model
@@ -206,11 +233,14 @@ def main(
         device_map=device_map,
         experts_only_qwen_1_5_moe_a_2_7b=experts_only_qwen_1_5_moe_a_2_7b,
         save_model_to_disk=save_model_to_disk,
+        skip_gate_qwen_1_5_moe_a_2_7b=skip_gate_qwen_1_5_moe_a_2_7b,
     )
     print(f"{args=}")

     if args.experts_only_qwen_1_5_moe_a_2_7b:
-        assert args.quant_type == "fp8", "unsupported"
+        assert args.quant_type in ("fp8", "mxfp4"), "unsupported"
+
+    assert not (args.skip_gate_qwen_1_5_moe_a_2_7b and args.experts_only_qwen_1_5_moe_a_2_7b), "unsupported"

     # Create output directory
     output_dir = Path(args.output_dir)
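Separately, a rough sketch of the throughput measurement implied by the `bench_tokens` argument documented above, written as a standalone helper. The prompt, greedy decoding, and CUDA synchronization points are assumptions; the actual script may measure differently.

```python
# Illustrative sketch (not part of this diff): time a fixed number of newly
# generated tokens with an already-loaded, quantized transformers model.
import time

import torch
from transformers import PreTrainedModel, PreTrainedTokenizerBase


def bench_generate(
    model: PreTrainedModel,
    tokenizer: PreTrainedTokenizerBase,
    prompt: str = "The capital of France is",  # placeholder prompt
    bench_tokens: int = 100,
) -> float:
    """Return decode throughput in tokens/sec for `bench_tokens` new tokens."""
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

    torch.cuda.synchronize()  # assumes a CUDA device, as in the script's device_map="cuda"
    start = time.perf_counter()
    out = model.generate(**inputs, max_new_tokens=bench_tokens, do_sample=False)
    torch.cuda.synchronize()
    elapsed = time.perf_counter() - start

    new_tokens = out.shape[-1] - inputs["input_ids"].shape[-1]
    return new_tokens / elapsed
```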