@@ -83,6 +83,7 @@ def get_quantization_config(args):
                 # TODO tool to find this (I used bisect on this tiny model).
                 activation_value_lb=1.0e-12,
             )
+
             if args.experts_only_qwen_1_5_moe_a_2_7b:
                 expert_fqn_to_config = {}
                 # TODO(future PR): this is annoying, I should be able to use a regex here
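
For context on the `activation_value_lb` knob: with dynamic float8 quantization the activation scale is derived from an absolute max, so an all-zero activation row would yield a zero scale and blow up on dequantize. A minimal sketch of the idea in plain PyTorch (not torchao's actual implementation; the reduction axis and the e4m3 max constant are illustrative assumptions):

```python
import torch

F8_E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max  # 448.0

def rowwise_fp8_scale(x: torch.Tensor, value_lb: float = 1.0e-12) -> torch.Tensor:
    # one scale per row of activations
    amax = x.abs().amax(dim=-1, keepdim=True)
    # without the lower bound, an all-zero row gives amax == 0 and scale == 0
    amax = amax.clamp(min=value_lb)
    return amax / F8_E4M3_MAX

x = torch.zeros(2, 8)            # pathological all-zero prompt activations
print(rowwise_fp8_scale(x))      # tiny but nonzero scales instead of zeros
```
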
@@ -125,6 +126,32 @@ def get_quantization_config(args):
                 return TorchAoConfig(
                     quant_type=module_fqn_to_config,
                 )
+            elif args.ffn_only_llama_4_scout:
+                # TODO gate this properly
+                expert_3d_weight_single_config = Float8DynamicActivationFloat8WeightConfig(
+                    # the weights of this model are stored in (B, K, N) layout, and we
+                    # need to quantize rowwise across the K axis, which is `PerRow(1)`.
+                    granularity=[PerRow(), PerRow(1)],
+                    # the 125m model has a lot of activation zeroes for some
+                    # prompts, need to set a lower bound to prevent scales from
+                    # being 0.
+                    # TODO seems like torchao should do this for me.
+                    # TODO tool to find this (I used bisect on this tiny model).
+                    activation_value_lb=1.0e-12,
+                )
+                module_fqn_to_config = ModuleFqnToConfig(
+                    {
+                        r"re:.*\.feed_forward\.experts\.gate_up_proj": expert_3d_weight_single_config,
+                        r"re:.*\.feed_forward\.experts\.down_proj": expert_3d_weight_single_config,
+                        r"re:.*\.shared_expert\.down_proj": single_config,
+                        r"re:.*\.shared_expert\.up_proj": single_config,
+                        r"re:.*\.shared_expert\.gate_proj": single_config,
+                    }
+                )
+                return TorchAoConfig(
+                    quant_type=module_fqn_to_config,
+                )
+
             else:
                 return TorchAoConfig(single_config)
         case "int4_weight_only":
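
The notable part of this branch is `PerRow(1)`: the routed-expert weights here are single 3D parameters in (num_experts, K, N) layout, so the weight scale has to be taken along the K (reduction) axis, dim 1, rather than the usual last dim. A standalone sketch of that scaling in plain PyTorch (shapes are illustrative; torchao's `Float8DynamicActivationFloat8WeightConfig` does this internally and its exact numerics are not reproduced here):

```python
import torch

F8_E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max  # 448.0

def quantize_3d_weight_over_k(w: torch.Tensor):
    # w: (num_experts, K, N); "rowwise across the K axis" means one scale
    # per (expert, output column), reducing over dim=1, i.e. PerRow(1).
    amax = w.abs().amax(dim=1, keepdim=True)               # (B, 1, N)
    scale = amax.clamp(min=1.0e-12) / F8_E4M3_MAX
    w_fp8 = (w / scale).clamp(-F8_E4M3_MAX, F8_E4M3_MAX).to(torch.float8_e4m3fn)
    return w_fp8, scale

w = torch.randn(4, 64, 128, dtype=torch.bfloat16)           # small illustrative sizes
w_fp8, scale = quantize_3d_weight_over_k(w)
print(w_fp8.shape, scale.shape)  # (4, 64, 128) and (4, 1, 128)
```

The `re:`-prefixed keys in `ModuleFqnToConfig` are matched as regexes against module FQNs, which is what lets a single entry cover the same projection across every layer.
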
@@ -345,6 +372,7 @@ def main(
     device_map: str = "cuda",
     experts_only_qwen_1_5_moe_a_2_7b: bool = False,
     skip_gate_qwen_1_5_moe_a_2_7b: bool = False,
+    ffn_only_llama_4_scout: bool = False,
     save_model_to_disk: bool = True,
 ):
     """
@@ -363,6 +391,7 @@ def main(
         device_map: Device mapping strategy
         experts_only_qwen_1_5_moe_a_2_7b: if True, quantizes experts only for Qwen1.5-MoE-A2.7B model
         skip_gate_qwen_1_5_moe_a_2_7b: if True, skips gate quantization for Qwen1.5-MoE-A2.7B model
+        ffn_only_llama_4_scout: if True, quantizes FFN modules only for the meta-llama/Llama-4-Scout-17B-16E-Instruct model
         save_model_to_disk: if True, saves quantized model to local disk
     """
     # Test prompts
@@ -397,6 +426,7 @@ def main(
         experts_only_qwen_1_5_moe_a_2_7b=experts_only_qwen_1_5_moe_a_2_7b,
         save_model_to_disk=save_model_to_disk,
         skip_gate_qwen_1_5_moe_a_2_7b=skip_gate_qwen_1_5_moe_a_2_7b,
+        ffn_only_llama_4_scout=ffn_only_llama_4_scout,
     )
     print(f"{args=}")
 
@@ -415,6 +445,7 @@ def main(
     # Get quantization config
     quantization_config = get_quantization_config(args)
 
+    # TODO(before land): clean up the chat processor code
     if args.model_name == "meta-llama/Llama-4-Scout-17B-16E-Instruct":
         # TODO(future): maybe unify with the else branch, need to figure
         # out the right syntax for preparing inputs and running generation
@@ -430,6 +461,27 @@ def main(
             torch_dtype=torch.bfloat16,
             quantization_config=quantization_config,
         )
+        print(quantized_model)
+
+        print(
+            "quantized_model.language_model.model.layers[47].feed_forward.experts.down_proj",
+            type(
+                quantized_model.language_model.model.layers[
+                    47
+                ].feed_forward.experts.down_proj
+            ),
+        )
+        print(
+            "quantized_model.language_model.model.layers[47].feed_forward.experts.gate_up_proj",
+            type(
+                quantized_model.language_model.model.layers[
+                    47
+                ].feed_forward.experts.gate_up_proj
+            ),
+        )
+
+        # breakpoint()
+        # return
 
         messages = []
         for prompt in prompts[:1]:
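
The two prints above spot-check a single layer; a more general sanity check is to walk every decoder layer and report the type of the 3D expert parameters (a sketch reusing the attribute path from the prints; depending on how torchao wraps the parameter, the quantized subclass may show up on the attribute itself or on its `.data`):

```python
for i, layer in enumerate(quantized_model.language_model.model.layers):
    experts = getattr(layer.feed_forward, "experts", None)
    if experts is None:
        continue
    print(i, type(experts.gate_up_proj), type(experts.down_proj))
```
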
@@ -462,6 +514,8 @@ def main(
         for response in responses:
             print(response)
 
+        return
+
     else:
         # Load tokenizer
         tokenizer = AutoTokenizer.from_pretrained(args.model_name)
@@ -495,8 +549,16 @@ def main(
     if args.save_model_to_disk:
         # Save quantized model
         print(f"\nSaving quantized model to: {output_dir}")
-        quantized_model.save_pretrained(output_dir, safe_serialization=False)
-        tokenizer.save_pretrained(output_dir)
+        if args.model_name == "meta-llama/Llama-4-Scout-17B-16E-Instruct":
+            quantized_model.save_pretrained(
+                output_dir, safe_serialization=False
+            )
+            processor.save_pretrained(output_dir)
+        else:
+            quantized_model.save_pretrained(
+                output_dir, safe_serialization=False
+            )
+            tokenizer.save_pretrained(output_dir)
 
     # Push to HuggingFace hub if requested
     if args.push_to_hub:
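
As a usage note, a checkpoint saved this way (with `safe_serialization=False`, since torchao tensor subclasses do not serialize to safetensors) can be reloaded from `output_dir` roughly like this (a sketch; the concrete model class and `device_map` are assumptions based on the Llama-4 branch above, and a torchao version that understands the serialized config must be installed at load time):

```python
import torch
from transformers import AutoProcessor, Llama4ForConditionalGeneration

reloaded = Llama4ForConditionalGeneration.from_pretrained(
    output_dir,                  # directory written by save_pretrained above
    torch_dtype=torch.bfloat16,
    device_map="cuda",
)
processor = AutoProcessor.from_pretrained(output_dir)
```
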