Commit 0e7b53e

Merge pull request #79 from vkuzo/20251104_llama4_scout
example of llama 4 scout torchao quant
2 parents: ee604d2 + d8528a7

hf_torchao_vllm/quantize_hf_model_with_torchao.py

Lines changed: 82 additions & 30 deletions
```diff
@@ -365,6 +365,15 @@ def main(
         skip_gate_qwen_1_5_moe_a_2_7b: if True, skips gate quantization for Qwen1.5-MoE-A2.7B model
         save_model_to_disk: if True, saves quantized model to local disk
     """
+    # Test prompts
+    prompts = [
+        "Why is Pytorch 2.0 the best machine learning compiler?",
+        "Hello, my name is",
+        "The president of the United States is",
+        "The capital of France is",
+        "The future of AI is",
+    ]
+
     # Set seed before creating the model
     set_seed(42)

@@ -406,40 +415,82 @@ def main(
     # Get quantization config
     quantization_config = get_quantization_config(args)

-    # Load and quantize model
-    print("Loading and quantizing model...")
-    quantized_model = AutoModelForCausalLM.from_pretrained(
-        args.model_name,
-        torch_dtype="bfloat16",
-        device_map=args.device_map,
-        quantization_config=quantization_config,
-    )
-    print(quantized_model)
+    if args.model_name == "meta-llama/Llama-4-Scout-17B-16E-Instruct":
+        # TODO(future): maybe unify with the else branch, need to figure
+        # out the right syntax for preparing inputs and running generation
+        # TODO(future): make this work for multiple prompts
+        from transformers import AutoProcessor, Llama4ForConditionalGeneration

-    # Load tokenizer
-    tokenizer = AutoTokenizer.from_pretrained(args.model_name)
+        processor = AutoProcessor.from_pretrained(args.model_name)
+        quantized_model = Llama4ForConditionalGeneration.from_pretrained(
+            args.model_name,
+            # Note: flex does not work with naive device_map="auto"
+            # attn_implementation="flex_attention",
+            device_map="auto",
+            torch_dtype=torch.bfloat16,
+            quantization_config=quantization_config,
+        )

-    # Test prompts
-    prompts = [
-        "Why is Pytorch 2.0 the best machine learning compiler?",
-        "Hello, my name is",
-        "The president of the United States is",
-        "The capital of France is",
-        "The future of AI is",
-    ]
+        messages = []
+        for prompt in prompts[:1]:
+            messages.append(
+                {
+                    "role": "user",
+                    "content": [
+                        {"type": "text", "text": prompt},
+                    ],
+                },
+            )

-    # Test generation
-    print("\nTesting quantized model generation...")
-    input_ids = tokenizer(prompts, return_tensors="pt", padding=True).to(
-        quantized_model.device
-    )
-    outputs = quantized_model.generate(
-        **input_ids, max_new_tokens=args.max_new_tokens
-    )
+        inputs = processor.apply_chat_template(
+            messages,
+            add_generation_prompt=True,
+            tokenize=True,
+            return_dict=True,
+            return_tensors="pt",
+            padding="longest",
+        ).to(quantized_model.device)
+
+        outputs = quantized_model.generate(
+            **inputs,
+            max_new_tokens=args.max_new_tokens,
+        )
+
+        responses = processor.batch_decode(
+            outputs[:, inputs["input_ids"].shape[-1] :]
+        )
+        for response in responses:
+            print(response)
+
+    else:
+        # Load tokenizer
+        tokenizer = AutoTokenizer.from_pretrained(args.model_name)
+        # breakpoint()
+
+        # Load and quantize model
+        print("Loading and quantizing model...")
+        quantized_model = AutoModelForCausalLM.from_pretrained(
+            args.model_name,
+            torch_dtype="bfloat16",
+            device_map=args.device_map,
+            quantization_config=quantization_config,
+        )
+        print(quantized_model)
+
+        # Test generation
+        print("\nTesting quantized model generation...")
+        input_ids = tokenizer(prompts, return_tensors="pt", padding=True).to(
+            quantized_model.device
+        )
+        outputs = quantized_model.generate(
+            **input_ids, max_new_tokens=args.max_new_tokens
+        )

-    for i, (prompt, output) in enumerate(zip(prompts, outputs, strict=False)):
-        generated_text = tokenizer.decode(output, skip_special_tokens=True)
-        print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
+        for i, (prompt, output) in enumerate(
+            zip(prompts, outputs, strict=False)
+        ):
+            generated_text = tokenizer.decode(output, skip_special_tokens=True)
+            print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")

     if args.save_model_to_disk:
         # Save quantized model
@@ -457,6 +508,7 @@ def main(
         tokenizer.push_to_hub(model_name)

     if args.save_model_to_disk:
+        # TODO(future): support this for LLaMa 4 Scout
         # Load saved model to verify
         print("\nLoading saved quantized model to verify...")
         # TODO: do we really need `weights_only=False` here?
```
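
The changed code still relies on `get_quantization_config(args)`, whose body is not part of this commit. For readers who want to try the new path in isolation, a minimal stand-in could look like the sketch below; the quant type and settings here are assumptions for illustration, not taken from the script.

```python
# Hypothetical stand-in for get_quantization_config(args).
# The real function in quantize_hf_model_with_torchao.py builds its config from
# CLI arguments and may select a different torchao quantization scheme.
from transformers import TorchAoConfig


def get_quantization_config_sketch():
    # int4 weight-only with group size 128 is an illustrative choice only.
    return TorchAoConfig("int4_weight_only", group_size=128)
```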

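A self-contained sketch of the new Llama 4 Scout branch (single prompt, chat template, generate, decode), assuming the stand-in config above; the prompt and `max_new_tokens` value are illustrative, since the script takes both from its CLI arguments.

```python
# Minimal sketch of the Llama 4 Scout path added in this commit; values marked
# "assumed" are illustrative and not taken from the commit.
import torch
from transformers import AutoProcessor, Llama4ForConditionalGeneration, TorchAoConfig

model_name = "meta-llama/Llama-4-Scout-17B-16E-Instruct"
quantization_config = TorchAoConfig("int4_weight_only", group_size=128)  # assumed

processor = AutoProcessor.from_pretrained(model_name)
model = Llama4ForConditionalGeneration.from_pretrained(
    model_name,
    device_map="auto",
    torch_dtype=torch.bfloat16,
    quantization_config=quantization_config,
)

# Single user prompt, mirroring the prompts[:1] slice in the diff.
messages = [
    {"role": "user", "content": [{"type": "text", "text": "The capital of France is"}]},
]
inputs = processor.apply_chat_template(
    messages,
    add_generation_prompt=True,
    tokenize=True,
    return_dict=True,
    return_tensors="pt",
    padding="longest",
).to(model.device)

outputs = model.generate(**inputs, max_new_tokens=64)  # value assumed
# Slice off the prompt tokens before decoding, as the diff does.
print(processor.batch_decode(outputs[:, inputs["input_ids"].shape[-1]:])[0])
```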