
Commit f0be3c9

[wip] llama 4 scout expert quant
Summary: Test Plan: Reviewers: Subscribers: Tasks: Tags:
1 parent 593a1d1 commit f0be3c9

File tree

2 files changed: +125 -50 lines changed

hf_torchao_vllm/quantize_hf_model_with_torchao.py

Lines changed: 122 additions & 50 deletions
@@ -83,6 +83,7 @@ def get_quantization_config(args):
                     # TODO tool to find this (I used bisect on this tiny model).
                     activation_value_lb=1.0e-12,
                 )
+
                 if args.experts_only_qwen_1_5_moe_a_2_7b:
                     expert_fqn_to_config = {}
                     # TODO(future PR): this is annoying, I should be able to use a regex here
@@ -125,6 +126,32 @@ def get_quantization_config(args):
                     return TorchAoConfig(
                         quant_type=module_fqn_to_config,
                     )
+            elif args.ffn_only_llama_4_scout:
+                # TODO gate this properly
+                expert_3d_weight_single_config = Float8DynamicActivationFloat8WeightConfig(
+                    # the weights of this model are stored in (B, K, N) layout, and we
+                    # need to quantize rowwise across the K axis, which is `PerRow(1)`.
+                    granularity=[PerRow(), PerRow(1)],
+                    # the 125m model has a lot of activation zeroes for some
+                    # prompts, need to set a lower bound to prevent scales from
+                    # being 0.
+                    # TODO seems like torchao should do this for me.
+                    # TODO tool to find this (I used bisect on this tiny model).
+                    activation_value_lb=1.0e-12,
+                )
+                module_fqn_to_config = ModuleFqnToConfig(
+                    {
+                        r"re:.*\.feed_forward\.experts\.gate_up_proj": expert_3d_weight_single_config,
+                        r"re:.*\.feed_forward\.experts\.down_proj": expert_3d_weight_single_config,
+                        r"re:.*\.shared_expert\.down_proj": single_config,
+                        r"re:.*\.shared_expert\.up_proj": single_config,
+                        r"re:.*\.shared_expert\.gate_proj": single_config,
+                    }
+                )
+                return TorchAoConfig(
+                    quant_type=module_fqn_to_config,
+                )
+
             else:
                 return TorchAoConfig(single_config)
         case "int4_weight_only":
@@ -318,6 +345,44 @@ def benchmark_model(model, input_ids, max_new_tokens, name=""):
     return elapsed


+def _inference_with_processor(
+    model,
+    processor,
+    prompts,
+    args,
+) -> None:
+    messages = []
+    for prompt in prompts[:1]:
+        messages.append(
+            {
+                "role": "user",
+                "content": [
+                    {"type": "text", "text": prompt},
+                ],
+            },
+        )
+
+    inputs = processor.apply_chat_template(
+        messages,
+        add_generation_prompt=True,
+        tokenize=True,
+        return_dict=True,
+        return_tensors="pt",
+        padding="longest",
+    ).to(model.device)
+
+    outputs = model.generate(
+        **inputs,
+        max_new_tokens=args.max_new_tokens,
+    )
+
+    responses = processor.batch_decode(
+        outputs[:, inputs["input_ids"].shape[-1] :]
+    )
+    for response in responses:
+        print(response)
+
+
 def main(
     model_name: str = "facebook/opt-125m",
     output_dir: str | None = None,
@@ -345,6 +410,7 @@ def main(
     device_map: str = "cuda",
     experts_only_qwen_1_5_moe_a_2_7b: bool = False,
     skip_gate_qwen_1_5_moe_a_2_7b: bool = False,
+    ffn_only_llama_4_scout: bool = False,
     save_model_to_disk: bool = True,
 ):
     """
@@ -363,6 +429,7 @@ def main(
         device_map: Device mapping strategy
         experts_only_qwen_1_5_moe_a_2_7b: if True, quantizes experts only for Qwen1.5-MoE-A2.7B model
         skip_gate_qwen_1_5_moe_a_2_7b: if True, skips gate quantization for Qwen1.5-MoE-A2.7B model
+        ffn_only_llama_4_scout: if True, quantizes the FFN only for meta-llama/Llama-4-Scout-17B-16E-Instruct
         save_model_to_disk: if True, saves quantized model to local disk
     """
     # Test prompts
@@ -397,6 +464,7 @@ def main(
         experts_only_qwen_1_5_moe_a_2_7b=experts_only_qwen_1_5_moe_a_2_7b,
         save_model_to_disk=save_model_to_disk,
         skip_gate_qwen_1_5_moe_a_2_7b=skip_gate_qwen_1_5_moe_a_2_7b,
+        ffn_only_llama_4_scout=ffn_only_llama_4_scout,
     )
     print(f"{args=}")

@@ -430,37 +498,26 @@ def main(
             torch_dtype=torch.bfloat16,
             quantization_config=quantization_config,
         )
+        print(quantized_model)

-        messages = []
-        for prompt in prompts[:1]:
-            messages.append(
-                {
-                    "role": "user",
-                    "content": [
-                        {"type": "text", "text": prompt},
-                    ],
-                },
-            )
-
-        inputs = processor.apply_chat_template(
-            messages,
-            add_generation_prompt=True,
-            tokenize=True,
-            return_dict=True,
-            return_tensors="pt",
-            padding="longest",
-        ).to(quantized_model.device)
-
-        outputs = quantized_model.generate(
-            **inputs,
-            max_new_tokens=args.max_new_tokens,
+        print(
+            "quantized_model.language_model.model.layers[47].feed_forward.experts.down_proj",
+            type(
+                quantized_model.language_model.model.layers[
+                    47
+                ].feed_forward.experts.down_proj
+            ),
         )
-
-        responses = processor.batch_decode(
-            outputs[:, inputs["input_ids"].shape[-1] :]
+        print(
+            "quantized_model.language_model.model.layers[47].feed_forward.experts.gate_up_proj",
+            type(
+                quantized_model.language_model.model.layers[
+                    47
+                ].feed_forward.experts.gate_up_proj
+            ),
         )
-        for response in responses:
-            print(response)
+
+        _inference_with_processor(quantized_model, processor, prompts, args)

     else:
         # Load tokenizer
@@ -495,8 +552,16 @@ def main(
     if args.save_model_to_disk:
         # Save quantized model
         print(f"\nSaving quantized model to: {output_dir}")
-        quantized_model.save_pretrained(output_dir, safe_serialization=False)
-        tokenizer.save_pretrained(output_dir)
+        if args.model_name == "meta-llama/Llama-4-Scout-17B-16E-Instruct":
+            quantized_model.save_pretrained(
+                output_dir, safe_serialization=False
+            )
+            processor.save_pretrained(output_dir)
+        else:
+            quantized_model.save_pretrained(
+                output_dir, safe_serialization=False
+            )
+            tokenizer.save_pretrained(output_dir)

     # Push to HuggingFace hub if requested
     if args.push_to_hub:
@@ -509,28 +574,35 @@ def main(

     if args.save_model_to_disk:
         # TODO(future): support this for LLaMa 4 Scout
-        # Load saved model to verify
         print("\nLoading saved quantized model to verify...")
-        # TODO: do we really need `weights_only=False` here?
-        loaded_model = AutoModelForCausalLM.from_pretrained(
-            output_dir,
-            device_map=args.device_map,
-            torch_dtype="auto",
-            weights_only=False,
-        )
+        if args.model_name == "meta-llama/Llama-4-Scout-17B-16E-Instruct":
+            _inference_with_processor(quantized_model, processor, prompts, args)
+            print("Verified inference with reloaded model")
+
+        else:
+            # Load saved model to verify
+            # TODO: do we really need `weights_only=False` here?
+            loaded_model = AutoModelForCausalLM.from_pretrained(
+                output_dir,
+                device_map=args.device_map,
+                torch_dtype="auto",
+                weights_only=False,
+            )

-        # Test loaded model with first prompt
-        test_prompt = prompts[0]
-        input_ids = tokenizer(test_prompt, return_tensors="pt").to(
-            loaded_model.device
-        )
-        output = loaded_model.generate(
-            **input_ids, max_new_tokens=args.max_new_tokens
-        )
-        generated_text = tokenizer.decode(output[0], skip_special_tokens=True)
-        print(
-            f"Verification - Prompt: {test_prompt!r}, Generated text: {generated_text!r}"
-        )
+            # Test loaded model with first prompt
+            test_prompt = prompts[0]
+            input_ids = tokenizer(test_prompt, return_tensors="pt").to(
+                loaded_model.device
+            )
+            output = loaded_model.generate(
+                **input_ids, max_new_tokens=args.max_new_tokens
+            )
+            generated_text = tokenizer.decode(
+                output[0], skip_special_tokens=True
+            )
+            print(
+                f"Verification - Prompt: {test_prompt!r}, Generated text: {generated_text!r}"
+            )

     # Benchmark if requested
     if args.benchmark:
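A note on the `PerRow(1)` granularity used above: Scout's fused expert weights live in a single 3D tensor per projection, stored in (num_experts, K, N) layout, so the rowwise scale has to be reduced over dim 1 (the K axis) rather than over the last dim. A minimal sketch of that shape reasoning in plain PyTorch; the shapes and variable names are illustrative only and are not taken from the commit or from torchao internals:

    import torch

    # toy shapes, not Scout's real config
    B, K, N = 16, 5120, 8192
    w = torch.randn(B, K, N, dtype=torch.bfloat16)

    # rowwise float8 scale: reduce over the K axis (dim 1), giving one scale
    # per (expert, output column) pair -- this is what PerRow(1) expresses
    amax = w.abs().amax(dim=1, keepdim=True).float()   # shape (B, 1, N)
    scale = amax / torch.finfo(torch.float8_e4m3fn).max
    w_fp8 = (w.float() / scale).clamp(-448.0, 448.0).to(torch.float8_e4m3fn)

    print(scale.shape)  # torch.Size([16, 1, 8192])

The first `PerRow()` in the granularity pair covers the (2D, rowwise) activations; the `PerRow(1)` entry covers these 3D weights.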

hf_torchao_vllm/run_quantized_model_in_vllm.py

Lines changed: 3 additions & 0 deletions
@@ -75,6 +75,7 @@ def main(
         model=model_name,
         tensor_parallel_size=tp_size,
         enforce_eager=not compile,
+        max_model_len=max_tokens,
     )

     # Print diagnostic information
@@ -86,6 +87,8 @@ def main(
     print(f"model_config: {model_config}")
     print(f"hf_config: {model_config.hf_config}")
     if print_model:
+        # TODO: fix this for latest vllm, lines below crash when building from
+        # source with https://www.internalfb.com/phabricator/paste/view/P2028278010
         model = llm.llm_engine.model_executor.driver_worker.model_runner.model
         print(f"model: {model}")
