@@ -65,9 +65,25 @@ def get_quantization_config(args):
         case "autoquant":
             return TorchAoConfig("autoquant", min_sqnr=args.min_sqnr)
         case "fp8":
-            return TorchAoConfig(
-                Float8DynamicActivationFloat8WeightConfig(granularity=gran)
-            )
+            single_config = Float8DynamicActivationFloat8WeightConfig(granularity=gran)
+            if args.experts_only_qwen_1_5_moe_a_2_7b:
+                from torchao.quantization import ModuleFqnToConfig
+                expert_fqn_to_config = {}
+                # TODO(future PR): this is annoying, I should be able to use a regex here
+                for layer_idx in range(24):
+                    for expert_idx in range(60):
+                        expert_fqn_to_config[f"model.layers.{layer_idx}.mlp.experts.{expert_idx}.gate_proj"] = single_config
+                        expert_fqn_to_config[f"model.layers.{layer_idx}.mlp.experts.{expert_idx}.up_proj"] = single_config
+                        expert_fqn_to_config[f"model.layers.{layer_idx}.mlp.experts.{expert_idx}.down_proj"] = single_config
+                module_fqn_to_config = ModuleFqnToConfig({
+                    "_default": None,
+                    **expert_fqn_to_config,
+                })
+                return TorchAoConfig(
+                    quant_type=module_fqn_to_config,
+                )
+            else:
+                return TorchAoConfig(single_config)
         case "int4_weight_only":
             return TorchAoConfig(Int4WeightOnlyConfig(group_size=128))
         case "int8_weight_only":
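The new fp8 branch quantizes only the expert projections: every gate_proj/up_proj/down_proj under model.layers.<i>.mlp.experts.<j> gets the float8 dynamic config, while "_default": None leaves everything else unquantized. As a minimal sketch (not part of the commit) of how the same mapping could be built without the nested loops flagged in the TODO, assuming PerRow() as a stand-in for the script's `gran` variable:

    # Sketch only: rebuild the commit's expert-only mapping with a dict comprehension.
    # Layer/expert counts (24 x 60) and the FQN pattern are taken from the loops above;
    # PerRow() is an assumed stand-in for the script's `gran` variable.
    from torchao.quantization import (
        Float8DynamicActivationFloat8WeightConfig,
        ModuleFqnToConfig,
    )
    from torchao.quantization.granularity import PerRow

    single_config = Float8DynamicActivationFloat8WeightConfig(granularity=PerRow())
    expert_fqn_to_config = {
        f"model.layers.{layer_idx}.mlp.experts.{expert_idx}.{proj}": single_config
        for layer_idx in range(24)
        for expert_idx in range(60)
        for proj in ("gate_proj", "up_proj", "down_proj")
    }
    # Modules not listed fall back to "_default"; None means "do not quantize".
    module_fqn_to_config = ModuleFqnToConfig({"_default": None, **expert_fqn_to_config})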
@@ -148,12 +164,14 @@ def main(
     benchmark: bool = False,
     bench_tokens: int = 100,
     device_map: str = "cuda",
+    experts_only_qwen_1_5_moe_a_2_7b: bool = False,
+    save_model_to_disk: bool = True,
 ):
     """
     Quantize a model with TorchAO and test its performance.
 
     Args:
-        model_name: Model to quantize (e.g., meta-llama/Meta-Llama-3-8B, facebook/opt-125m)
+        model_name: Model to quantize (e.g., meta-llama/Meta-Llama-3-8B, facebook/opt-125m, Qwen/Qwen1.5-MoE-A2.7B)
         output_dir: Directory to save the quantized model
         push_to_hub: HF Hub repo name to push the model (e.g., 'your-username/model-name')
         quant_type: Quantization type to use
@@ -163,6 +181,8 @@ def main(
         benchmark: Run benchmarking comparison
         bench_tokens: Number of tokens to generate for benchmarking
         device_map: Device mapping strategy
+        experts_only_qwen_1_5_moe_a_2_7b: if True, quantize only the experts of the Qwen1.5-MoE-A2.7B model
+        save_model_to_disk: if True, save the quantized model to local disk
     """
     # Set seed before creating the model
     set_seed(42)
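With the two new keyword arguments in place, a hypothetical call that exercises them might look like the following (values are illustrative, the remaining parameters are assumed to keep their defaults; per the asserts added later in this commit, the experts-only flag requires quant_type="fp8", and benchmarking requires save_model_to_disk=True):

    # Illustrative only: exercise the two flags added in this commit.
    main(
        model_name="Qwen/Qwen1.5-MoE-A2.7B",
        quant_type="fp8",                        # required when the experts-only flag is set
        experts_only_qwen_1_5_moe_a_2_7b=True,   # quantize only the expert projections
        save_model_to_disk=False,                # skip the save/reload verification pass
    )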
@@ -184,9 +204,13 @@ def main(
         benchmark=benchmark,
         bench_tokens=bench_tokens,
         device_map=device_map,
+        experts_only_qwen_1_5_moe_a_2_7b=experts_only_qwen_1_5_moe_a_2_7b,
+        save_model_to_disk=save_model_to_disk,
     )
-    print(f"Using Model name: {args.model_name}")
-    print(f"Quantization type: {args.quant_type}")
+    print(f"{args=}")
+
+    if args.experts_only_qwen_1_5_moe_a_2_7b:
+        assert args.quant_type == "fp8", "unsupported"
 
     # Create output directory
     output_dir = Path(args.output_dir)
@@ -228,10 +252,11 @@ def main(
         generated_text = tokenizer.decode(output, skip_special_tokens=True)
         print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
 
-    # Save quantized model
-    print(f"\n📁 Saving quantized model to: {output_dir}")
-    quantized_model.save_pretrained(output_dir, safe_serialization=False)
-    tokenizer.save_pretrained(output_dir)
+    if args.save_model_to_disk:
+        # Save quantized model
+        print(f"\nSaving quantized model to: {output_dir}")
+        quantized_model.save_pretrained(output_dir, safe_serialization=False)
+        tokenizer.save_pretrained(output_dir)
 
     # Push to HuggingFace hub if requested
     if args.push_to_hub:
@@ -242,22 +267,24 @@ def main(
         quantized_model.push_to_hub(model_name, safe_serialization=False)
         tokenizer.push_to_hub(model_name)
 
-    # Load saved model to verify
-    print("\nLoading saved quantized model to verify...")
-    # TODO: do we really need `weights_only=False` here?
-    loaded_model = AutoModelForCausalLM.from_pretrained(
-        output_dir, device_map=args.device_map, torch_dtype="auto", weights_only=False,
-    )
+    if args.save_model_to_disk:
+        # Load saved model to verify
+        print("\nLoading saved quantized model to verify...")
+        # TODO: do we really need `weights_only=False` here?
+        loaded_model = AutoModelForCausalLM.from_pretrained(
+            output_dir, device_map=args.device_map, torch_dtype="auto", weights_only=False,
+        )
 
-    # Test loaded model with first prompt
-    test_prompt = prompts[0]
-    input_ids = tokenizer(test_prompt, return_tensors="pt").to(loaded_model.device)
-    output = loaded_model.generate(**input_ids, max_new_tokens=args.max_new_tokens)
-    generated_text = tokenizer.decode(output[0], skip_special_tokens=True)
-    print(f"Verification - Prompt: {test_prompt!r}, Generated text: {generated_text!r}")
+        # Test loaded model with first prompt
+        test_prompt = prompts[0]
+        input_ids = tokenizer(test_prompt, return_tensors="pt").to(loaded_model.device)
+        output = loaded_model.generate(**input_ids, max_new_tokens=args.max_new_tokens)
+        generated_text = tokenizer.decode(output[0], skip_special_tokens=True)
+        print(f"Verification - Prompt: {test_prompt!r}, Generated text: {generated_text!r}")
 
     # Benchmark if requested
     if args.benchmark:
+        assert args.save_model_to_disk, "unsupported"
         print("\nBenchmarking models...")
         # Benchmark quantized model
         print("Benchmarking quantized model:")