@@ -83,6 +83,7 @@ def get_quantization_config(args):
                 # TODO tool to find this (I used bisect on this tiny model).
                 activation_value_lb=1.0e-12,
             )
+
             if args.experts_only_qwen_1_5_moe_a_2_7b:
                 expert_fqn_to_config = {}
                 # TODO(future PR): this is annoying, I should be able to use a regex here
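
The `re:`-prefixed keys used in the new Llama 4 Scout branch below address this TODO: `ModuleFqnToConfig` treats such keys as regular expressions matched against module FQNs. A minimal sketch of that selection logic, assuming the pattern is matched against the full FQN (the FQNs listed here are hypothetical examples following Llama 4 Scout's layer naming):

    import re

    # Hypothetical module FQNs, mirroring the layer naming used in the config below.
    fqns = [
        "language_model.model.layers.0.feed_forward.experts.gate_up_proj",
        "language_model.model.layers.47.feed_forward.experts.down_proj",
        "language_model.model.layers.47.feed_forward.router",
    ]

    # One of the "re:"-prefixed keys from the new config, with the prefix stripped.
    pattern = r".*\.feed_forward\.experts\.gate_up_proj"

    for fqn in fqns:
        # Assumption: torchao matches the pattern against the full module FQN.
        print(fqn, bool(re.fullmatch(pattern, fqn)))
    # Only the first FQN matches, so only gate_up_proj expert modules would receive
    # the 3D-weight config; the router is left unquantized.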
@@ -125,6 +126,32 @@ def get_quantization_config(args):
                 return TorchAoConfig(
                     quant_type=module_fqn_to_config,
                 )
+            elif args.ffn_only_llama_4_scout:
+                # TODO gate this properly
+                expert_3d_weight_single_config = Float8DynamicActivationFloat8WeightConfig(
+                    # the weights of this model are stored in (B, K, N) layout, and we
+                    # need to quantize rowwise across the K axis, which is `PerRow(1)`.
+                    granularity=[PerRow(), PerRow(1)],
+                    # the 125m model has a lot of activation zeroes for some
+                    # prompts, need to set a lower bound to prevent scales from
+                    # being 0.
+                    # TODO seems like torchao should do this for me.
+                    # TODO tool to find this (I used bisect on this tiny model).
+                    activation_value_lb=1.0e-12,
+                )
+                module_fqn_to_config = ModuleFqnToConfig(
+                    {
+                        r"re:.*\.feed_forward\.experts\.gate_up_proj": expert_3d_weight_single_config,
+                        r"re:.*\.feed_forward\.experts\.down_proj": expert_3d_weight_single_config,
+                        r"re:.*\.shared_expert\.down_proj": single_config,
+                        r"re:.*\.shared_expert\.up_proj": single_config,
+                        r"re:.*\.shared_expert\.gate_proj": single_config,
+                    }
+                )
+                return TorchAoConfig(
+                    quant_type=module_fqn_to_config,
+                )
+
             else:
                 return TorchAoConfig(single_config)
         case "int4_weight_only":
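
The `granularity=[PerRow(), PerRow(1)]` choice above encodes the (B, K, N) expert weight layout noted in the comment: weight scales are computed per row along the K axis rather than along the default last dimension. A minimal sketch of that rowwise float8 scaling in plain PyTorch; shapes are made up, the reduction axis follows the comment, and torchao's actual kernels and `PerRow(dim)` convention may differ in detail:

    import torch

    # Made-up expert weight in the (B, K, N) layout described above:
    # B experts, K input features (contraction dim), N output features.
    B, K, N = 4, 64, 128
    w = torch.randn(B, K, N, dtype=torch.bfloat16)

    # Rowwise scaling across dim 1 (the K axis): reduce amax over K, giving one
    # scale per (expert, output column), i.e. a (B, 1, N) scale tensor.
    f8_max = torch.finfo(torch.float8_e4m3fn).max
    amax = w.abs().amax(dim=1, keepdim=True).float()
    # Guard against all-zero rows producing zero scales, analogous in spirit to
    # the activation_value_lb lower bound used for activations above.
    scale = (amax / f8_max).clamp(min=1e-12)

    w_f8 = (w.float() / scale).clamp(-f8_max, f8_max).to(torch.float8_e4m3fn)
    w_dq = w_f8.float() * scale  # dequantize to sanity-check the round trip

    print(scale.shape)                     # torch.Size([4, 1, 128])
    print((w.float() - w_dq).abs().max())  # small quantization error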
@@ -318,6 +345,44 @@ def benchmark_model(model, input_ids, max_new_tokens, name=""):
     return elapsed


+def _inference_with_processor(
+    model,
+    processor,
+    prompts,
+    args,
+) -> None:
+    messages = []
+    for prompt in prompts[:1]:
+        messages.append(
+            {
+                "role": "user",
+                "content": [
+                    {"type": "text", "text": prompt},
+                ],
+            },
+        )
+
+    inputs = processor.apply_chat_template(
+        messages,
+        add_generation_prompt=True,
+        tokenize=True,
+        return_dict=True,
+        return_tensors="pt",
+        padding="longest",
+    ).to(model.device)
+
+    outputs = model.generate(
+        **inputs,
+        max_new_tokens=args.max_new_tokens,
+    )
+
+    responses = processor.batch_decode(
+        outputs[:, inputs["input_ids"].shape[-1] :]
+    )
+    for response in responses:
+        print(response)
+
+
 def main(
     model_name: str = "facebook/opt-125m",
     output_dir: str | None = None,
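
`_inference_with_processor` decodes only the newly generated tokens: `generate` returns the prompt followed by the continuation, and the slice `outputs[:, inputs["input_ids"].shape[-1]:]` drops the prompt portion before `batch_decode`. A toy illustration with made-up token ids:

    import torch

    # Pretend the chat template produced 5 prompt tokens and generate appended 3 new ones.
    input_ids = torch.tensor([[101, 102, 103, 104, 105]])
    outputs = torch.tensor([[101, 102, 103, 104, 105, 7, 8, 9]])

    new_tokens = outputs[:, input_ids.shape[-1]:]
    print(new_tokens)  # tensor([[7, 8, 9]]) -- only the continuation gets decoded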
@@ -345,6 +410,7 @@ def main(
     device_map: str = "cuda",
     experts_only_qwen_1_5_moe_a_2_7b: bool = False,
     skip_gate_qwen_1_5_moe_a_2_7b: bool = False,
+    ffn_only_llama_4_scout: bool = False,
     save_model_to_disk: bool = True,
 ):
350416 """
@@ -363,6 +429,7 @@ def main(
         device_map: Device mapping strategy
         experts_only_qwen_1_5_moe_a_2_7b: if True, quantizes experts only for Qwen1.5-MoE-A2.7B model
         skip_gate_qwen_1_5_moe_a_2_7b: if True, skips gate quantization for Qwen1.5-MoE-A2.7B model
+        ffn_only_llama_4_scout: if True, quantizes FFN modules only for meta-llama/Llama-4-Scout-17B-16E-Instruct
         save_model_to_disk: if True, saves quantized model to local disk
     """
     # Test prompts
@@ -397,6 +464,7 @@ def main(
         experts_only_qwen_1_5_moe_a_2_7b=experts_only_qwen_1_5_moe_a_2_7b,
         save_model_to_disk=save_model_to_disk,
         skip_gate_qwen_1_5_moe_a_2_7b=skip_gate_qwen_1_5_moe_a_2_7b,
+        ffn_only_llama_4_scout=ffn_only_llama_4_scout,
     )
     print(f"{args=}")

@@ -430,37 +498,26 @@ def main(
             torch_dtype=torch.bfloat16,
             quantization_config=quantization_config,
         )
+        print(quantized_model)

-        messages = []
-        for prompt in prompts[:1]:
-            messages.append(
-                {
-                    "role": "user",
-                    "content": [
-                        {"type": "text", "text": prompt},
-                    ],
-                },
-            )
-
-        inputs = processor.apply_chat_template(
-            messages,
-            add_generation_prompt=True,
-            tokenize=True,
-            return_dict=True,
-            return_tensors="pt",
-            padding="longest",
-        ).to(quantized_model.device)
-
-        outputs = quantized_model.generate(
-            **inputs,
-            max_new_tokens=args.max_new_tokens
+        print(
+            "quantized_model.language_model.model.layers[47].feed_forward.experts.down_proj",
+            type(
+                quantized_model.language_model.model.layers[
+                    47
+                ].feed_forward.experts.down_proj
+            ),
         )
-
-        responses = processor.batch_decode(
-            outputs[:, inputs["input_ids"].shape[-1] :]
+        print(
+            "quantized_model.language_model.model.layers[47].feed_forward.experts.gate_up_proj",
+            type(
+                quantized_model.language_model.model.layers[
+                    47
+                ].feed_forward.experts.gate_up_proj
+            ),
         )
-        for response in responses:
-            print(response)
+
+        _inference_with_processor(quantized_model, processor, prompts, args)

     else:
         # Load tokenizer
@@ -495,8 +552,16 @@ def main(
     if args.save_model_to_disk:
         # Save quantized model
         print(f"\nSaving quantized model to: {output_dir}")
-        quantized_model.save_pretrained(output_dir, safe_serialization=False)
-        tokenizer.save_pretrained(output_dir)
+        if args.model_name == "meta-llama/Llama-4-Scout-17B-16E-Instruct":
+            quantized_model.save_pretrained(
+                output_dir, safe_serialization=False
+            )
+            processor.save_pretrained(output_dir)
+        else:
+            quantized_model.save_pretrained(
+                output_dir, safe_serialization=False
+            )
+            tokenizer.save_pretrained(output_dir)

     # Push to HuggingFace hub if requested
     if args.push_to_hub:
@@ -509,28 +574,35 @@ def main(

     if args.save_model_to_disk:
         # TODO(future): support this for LLaMa 4 Scout
-        # Load saved model to verify
         print("\nLoading saved quantized model to verify...")
-        # TODO: do we really need `weights_only=False` here?
-        loaded_model = AutoModelForCausalLM.from_pretrained(
-            output_dir,
-            device_map=args.device_map,
-            torch_dtype="auto",
-            weights_only=False,
-        )
+        if args.model_name == "meta-llama/Llama-4-Scout-17B-16E-Instruct":
+            _inference_with_processor(quantized_model, processor, prompts, args)
+            print("Verified inference with reloaded model")
+
+        else:
+            # Load saved model to verify
+            # TODO: do we really need `weights_only=False` here?
+            loaded_model = AutoModelForCausalLM.from_pretrained(
+                output_dir,
+                device_map=args.device_map,
+                torch_dtype="auto",
+                weights_only=False,
+            )

-        # Test loaded model with first prompt
-        test_prompt = prompts[0]
-        input_ids = tokenizer(test_prompt, return_tensors="pt").to(
-            loaded_model.device
-        )
-        output = loaded_model.generate(
-            **input_ids, max_new_tokens=args.max_new_tokens
-        )
-        generated_text = tokenizer.decode(output[0], skip_special_tokens=True)
-        print(
-            f"Verification - Prompt: {test_prompt!r}, Generated text: {generated_text!r}"
-        )
+            # Test loaded model with first prompt
+            test_prompt = prompts[0]
+            input_ids = tokenizer(test_prompt, return_tensors="pt").to(
+                loaded_model.device
+            )
+            output = loaded_model.generate(
+                **input_ids, max_new_tokens=args.max_new_tokens
+            )
+            generated_text = tokenizer.decode(
+                output[0], skip_special_tokens=True
+            )
+            print(
+                f"Verification - Prompt: {test_prompt!r}, Generated text: {generated_text!r}"
+            )

     # Benchmark if requested
     if args.benchmark: