Commit 4d43646

[wip] llama 4 scout expert quant
Summary:
Test Plan:
Reviewers:
Subscribers:
Tasks:
Tags:
1 parent 593a1d1 commit 4d43646

File tree

1 file changed: +64 additions, −2 deletions


hf_torchao_vllm/quantize_hf_model_with_torchao.py

Lines changed: 64 additions & 2 deletions
@@ -83,6 +83,7 @@ def get_quantization_config(args):
                 # TODO tool to find this (I used bisect on this tiny model).
                 activation_value_lb=1.0e-12,
             )
+
             if args.experts_only_qwen_1_5_moe_a_2_7b:
                 expert_fqn_to_config = {}
                 # TODO(future PR): this is annoying, I should be able to use a regex here
@@ -125,6 +126,32 @@ def get_quantization_config(args):
                 return TorchAoConfig(
                     quant_type=module_fqn_to_config,
                 )
+            elif args.ffn_only_llama_4_scout:
+                # TODO gate this properly
+                expert_3d_weight_single_config = Float8DynamicActivationFloat8WeightConfig(
+                    # the weights of this model are stored in (B, K, N) layout, and we
+                    # need to quantize rowwise across the K axis, which is `PerRow(1)`.
+                    granularity=[PerRow(), PerRow(1)],
+                    # the 125m model has a lot of activation zeroes for some
+                    # prompts, need to set a lower bound to prevent scales from
+                    # being 0.
+                    # TODO seems like torchao should do this for me.
+                    # TODO tool to find this (I used bisect on this tiny model).
+                    activation_value_lb=1.0e-12,
+                )
+                module_fqn_to_config = ModuleFqnToConfig(
+                    {
+                        r"re:.*\.feed_forward\.experts\.gate_up_proj": expert_3d_weight_single_config,
+                        r"re:.*\.feed_forward\.experts\.down_proj": expert_3d_weight_single_config,
+                        r"re:.*\.shared_expert\.down_proj": single_config,
+                        r"re:.*\.shared_expert\.up_proj": single_config,
+                        r"re:.*\.shared_expert\.gate_proj": single_config,
+                    }
+                )
+                return TorchAoConfig(
+                    quant_type=module_fqn_to_config,
+                )
+
             else:
                 return TorchAoConfig(single_config)
         case "int4_weight_only":
@@ -345,6 +372,7 @@ def main(
     device_map: str = "cuda",
     experts_only_qwen_1_5_moe_a_2_7b: bool = False,
     skip_gate_qwen_1_5_moe_a_2_7b: bool = False,
+    ffn_only_llama_4_scout: bool = False,
     save_model_to_disk: bool = True,
 ):
     """
@@ -363,6 +391,7 @@ def main(
         device_map: Device mapping strategy
         experts_only_qwen_1_5_moe_a_2_7b: if True, quantizes experts only for Qwen1.5-MoE-A2.7B model
         skip_gate_qwen_1_5_moe_a_2_7b: if True, skips gate quantization for Qwen1.5-MoE-A2.7B model
+        ffn_only_llama_4_scout: if True, quantizes FFN only for meta-llama/Llama-4-Scout-17B-16E-Instruct
         save_model_to_disk: if True, saves quantized model to local disk
     """
     # Test prompts
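
As a usage note, the new flag would presumably be passed straight through to main(); a hypothetical invocation is sketched below. Only ffn_only_llama_4_scout, save_model_to_disk, and model_name are visible in this diff; the quant type string and import path are assumptions and may not match the script exactly.

# Hypothetical invocation; argument names not shown in this diff are assumed.
from quantize_hf_model_with_torchao import main

main(
    model_name="meta-llama/Llama-4-Scout-17B-16E-Instruct",
    quant_type="float8_dyn_act_float8_weight",  # assumed spelling of the float8 case
    ffn_only_llama_4_scout=True,
    save_model_to_disk=True,
)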
@@ -397,6 +426,7 @@ def main(
         experts_only_qwen_1_5_moe_a_2_7b=experts_only_qwen_1_5_moe_a_2_7b,
         save_model_to_disk=save_model_to_disk,
         skip_gate_qwen_1_5_moe_a_2_7b=skip_gate_qwen_1_5_moe_a_2_7b,
+        ffn_only_llama_4_scout=ffn_only_llama_4_scout,
     )
     print(f"{args=}")
 
@@ -415,6 +445,7 @@ def main(
     # Get quantization config
     quantization_config = get_quantization_config(args)
 
+    # TODO(before land): clean up the chat processor code
     if args.model_name == "meta-llama/Llama-4-Scout-17B-16E-Instruct":
         # TODO(future): maybe unify with the else branch, need to figure
         # out the right syntax for preparing inputs and running generation
@@ -430,6 +461,27 @@ def main(
             torch_dtype=torch.bfloat16,
             quantization_config=quantization_config,
         )
+        print(quantized_model)
+
+        print(
+            "quantized_model.language_model.model.layers[47].feed_forward.experts.down_proj",
+            type(
+                quantized_model.language_model.model.layers[
+                    47
+                ].feed_forward.experts.down_proj
+            ),
+        )
+        print(
+            "quantized_model.language_model.model.layers[47].feed_forward.experts.gate_up_proj",
+            type(
+                quantized_model.language_model.model.layers[
+                    47
+                ].feed_forward.experts.gate_up_proj
+            ),
+        )
+
+        # breakpoint()
+        # return
 
         messages = []
         for prompt in prompts[:1]:
@@ -462,6 +514,8 @@ def main(
         for response in responses:
             print(response)
 
+        return
+
     else:
         # Load tokenizer
         tokenizer = AutoTokenizer.from_pretrained(args.model_name)
@@ -495,8 +549,16 @@ def main(
     if args.save_model_to_disk:
         # Save quantized model
         print(f"\nSaving quantized model to: {output_dir}")
-        quantized_model.save_pretrained(output_dir, safe_serialization=False)
-        tokenizer.save_pretrained(output_dir)
+        if args.model_name == "meta-llama/Llama-4-Scout-17B-16E-Instruct":
+            quantized_model.save_pretrained(
+                output_dir, safe_serialization=False
+            )
+            processor.save_pretrained(output_dir)
+        else:
+            quantized_model.save_pretrained(
+                output_dir, safe_serialization=False
+            )
+            tokenizer.save_pretrained(output_dir)
 
     # Push to HuggingFace hub if requested
     if args.push_to_hub:
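
Since the Scout branch now saves the multimodal processor instead of a tokenizer, a round-trip reload of the saved checkpoint might look like the sketch below. The output path is a placeholder, and the model class is an assumption based on recent transformers releases where Llama 4 Scout loads as a conditional-generation model.

# Hypothetical reload of the saved checkpoint; output_dir is a placeholder and
# the model class is assumed, not confirmed by this diff.
import torch
from transformers import AutoProcessor, Llama4ForConditionalGeneration

output_dir = "data/llama-4-scout-float8-ffn"  # placeholder path
processor = AutoProcessor.from_pretrained(output_dir)
model = Llama4ForConditionalGeneration.from_pretrained(
    output_dir,
    torch_dtype=torch.bfloat16,
    device_map="cuda",
)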
