172 changes: 122 additions & 50 deletions hf_torchao_vllm/quantize_hf_model_with_torchao.py
@@ -83,6 +83,7 @@ def get_quantization_config(args):
# TODO: add a tool to find this value (I used bisect on this tiny model).
activation_value_lb=1.0e-12,
)

if args.experts_only_qwen_1_5_moe_a_2_7b:
expert_fqn_to_config = {}
# TODO(future PR): this is annoying, I should be able to use a regex here
@@ -125,6 +126,32 @@ def get_quantization_config(args):
return TorchAoConfig(
quant_type=module_fqn_to_config,
)
elif args.ffn_only_llama_4_scout:
# TODO gate this properly
expert_3d_weight_single_config = Float8DynamicActivationFloat8WeightConfig(
# the weights of this model are stored in (B, K, N) layout, and we
# need to quantize rowwise across the K axis, which is `PerRow(1)`.
granularity=[PerRow(), PerRow(1)],
# some prompts produce a lot of activation zeros (originally observed on
# the 125m model), so set a lower bound to prevent scales from being 0.
# TODO: seems like torchao should do this for me.
# TODO: add a tool to find this value (I used bisect on that tiny model).
activation_value_lb=1.0e-12,
)
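# Illustrative note (an assumption about torchao's convention, not verified here):
# the two granularity entries are read as (activation, weight). Activations keep the
# default per-row scaling, while the 3D (B, K, N) expert weights presumably get one
# scale per (expert, output column) by reducing over the K axis, i.e. dim 1.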
module_fqn_to_config = ModuleFqnToConfig(
{
r"re:.*\.feed_forward\.experts\.gate_up_proj": expert_3d_weight_single_config,
r"re:.*\.feed_forward\.experts\.down_proj": expert_3d_weight_single_config,
r"re:.*\.shared_expert\.down_proj": single_config,
r"re:.*\.shared_expert\.up_proj": single_config,
r"re:.*\.shared_expert\.gate_proj": single_config,
}
)
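# These regex keys are expected to match module FQNs such as
# "language_model.model.layers.47.feed_forward.experts.gate_up_proj"
# (inferred from the attribute path printed in main() below).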
return TorchAoConfig(
quant_type=module_fqn_to_config,
)

else:
return TorchAoConfig(single_config)
case "int4_weight_only":
@@ -318,6 +345,44 @@ def benchmark_model(model, input_ids, max_new_tokens, name=""):
return elapsed


def _inference_with_processor(
model,
processor,
prompts,
args,
) -> None:
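"""Run a chat-templated generation on the first prompt using the given
processor and print the decoded responses (used for processor-based
models such as Llama 4 Scout)."""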
messages = []
for prompt in prompts[:1]:
messages.append(
{
"role": "user",
"content": [
{"type": "text", "text": prompt},
],
},
)

inputs = processor.apply_chat_template(
messages,
add_generation_prompt=True,
tokenize=True,
return_dict=True,
return_tensors="pt",
padding="longest",
).to(model.device)

outputs = model.generate(
**inputs,
max_new_tokens=args.max_new_tokens,
)

responses = processor.batch_decode(
outputs[:, inputs["input_ids"].shape[-1] :]
)
for response in responses:
print(response)


def main(
model_name: str = "facebook/opt-125m",
output_dir: str | None = None,
@@ -345,6 +410,7 @@ def main(
device_map: str = "cuda",
experts_only_qwen_1_5_moe_a_2_7b: bool = False,
skip_gate_qwen_1_5_moe_a_2_7b: bool = False,
ffn_only_llama_4_scout: bool = False,
save_model_to_disk: bool = True,
):
"""
@@ -363,6 +429,7 @@ def main(
device_map: Device mapping strategy
experts_only_qwen_1_5_moe_a_2_7b: if True, quantizes experts only for Qwen1.5-MoE-A2.7B model
skip_gate_qwen_1_5_moe_a_2_7b: if True, skips gate quantization for Qwen1.5-MoE-A2.7B model
ffn_only_llama_4_scout: if True, quantizes only the FFN layers of the meta-llama/Llama-4-Scout-17B-16E-Instruct model
save_model_to_disk: if True, saves quantized model to local disk
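
Example (illustrative; only parameters visible in this diff are used, and the
quant-type argument is omitted because its name is not shown here):
main(
model_name="meta-llama/Llama-4-Scout-17B-16E-Instruct",
ffn_only_llama_4_scout=True,
save_model_to_disk=True,
)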
"""
# Test prompts
@@ -397,6 +464,7 @@ def main(
experts_only_qwen_1_5_moe_a_2_7b=experts_only_qwen_1_5_moe_a_2_7b,
save_model_to_disk=save_model_to_disk,
skip_gate_qwen_1_5_moe_a_2_7b=skip_gate_qwen_1_5_moe_a_2_7b,
ffn_only_llama_4_scout=ffn_only_llama_4_scout,
)
print(f"{args=}")

@@ -430,37 +498,26 @@ def main(
torch_dtype=torch.bfloat16,
quantization_config=quantization_config,
)
print(quantized_model)

messages = []
for prompt in prompts[:1]:
messages.append(
{
"role": "user",
"content": [
{"type": "text", "text": prompt},
],
},
)

inputs = processor.apply_chat_template(
messages,
add_generation_prompt=True,
tokenize=True,
return_dict=True,
return_tensors="pt",
padding="longest",
).to(quantized_model.device)

outputs = quantized_model.generate(
**inputs,
max_new_tokens=args.max_new_tokens,
print(
"quantized_model.language_model.model.layers[47].feed_forward.experts.down_proj",
type(
quantized_model.language_model.model.layers[
47
].feed_forward.experts.down_proj
),
)

responses = processor.batch_decode(
outputs[:, inputs["input_ids"].shape[-1] :]
print(
"quantized_model.language_model.model.layers[47].feed_forward.experts.gate_up_proj",
type(
quantized_model.language_model.model.layers[
47
].feed_forward.experts.gate_up_proj
),
)
for response in responses:
print(response)

_inference_with_processor(quantized_model, processor, prompts, args)

else:
# Load tokenizer
@@ -495,8 +552,16 @@ def main(
if args.save_model_to_disk:
# Save quantized model
print(f"\nSaving quantized model to: {output_dir}")
quantized_model.save_pretrained(output_dir, safe_serialization=False)
tokenizer.save_pretrained(output_dir)
if args.model_name == "meta-llama/Llama-4-Scout-17B-16E-Instruct":
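# Llama 4 Scout is driven through a processor (see _inference_with_processor),
# so save the processor alongside the weights instead of a tokenizer.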
quantized_model.save_pretrained(
output_dir, safe_serialization=False
)
processor.save_pretrained(output_dir)
else:
quantized_model.save_pretrained(
output_dir, safe_serialization=False
)
tokenizer.save_pretrained(output_dir)

# Push to HuggingFace hub if requested
if args.push_to_hub:
@@ -509,28 +574,35 @@ def main(

if args.save_model_to_disk:
# TODO(future): support this for Llama 4 Scout
# Load saved model to verify
print("\nLoading saved quantized model to verify...")
# TODO: do we really need `weights_only=False` here?
loaded_model = AutoModelForCausalLM.from_pretrained(
output_dir,
device_map=args.device_map,
torch_dtype="auto",
weights_only=False,
)
if args.model_name == "meta-llama/Llama-4-Scout-17B-16E-Instruct":
_inference_with_processor(quantized_model, processor, prompts, args)
print("Verified inference with reloaded model")

else:
# Load saved model to verify
# TODO: do we really need `weights_only=False` here?
loaded_model = AutoModelForCausalLM.from_pretrained(
output_dir,
device_map=args.device_map,
torch_dtype="auto",
weights_only=False,
)

# Test loaded model with first prompt
test_prompt = prompts[0]
input_ids = tokenizer(test_prompt, return_tensors="pt").to(
loaded_model.device
)
output = loaded_model.generate(
**input_ids, max_new_tokens=args.max_new_tokens
)
generated_text = tokenizer.decode(output[0], skip_special_tokens=True)
print(
f"Verification - Prompt: {test_prompt!r}, Generated text: {generated_text!r}"
)
# Test loaded model with first prompt
test_prompt = prompts[0]
input_ids = tokenizer(test_prompt, return_tensors="pt").to(
loaded_model.device
)
output = loaded_model.generate(
**input_ids, max_new_tokens=args.max_new_tokens
)
generated_text = tokenizer.decode(
output[0], skip_special_tokens=True
)
print(
f"Verification - Prompt: {test_prompt!r}, Generated text: {generated_text!r}"
)

# Benchmark if requested
if args.benchmark:
3 changes: 3 additions & 0 deletions hf_torchao_vllm/run_quantized_model_in_vllm.py
@@ -75,6 +75,7 @@ def main(
model=model_name,
tensor_parallel_size=tp_size,
enforce_eager=not compile,
max_model_len=max_tokens,
)
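# Capping max_model_len bounds the context length and hence the KV-cache
# allocation; note it limits prompt + generated tokens, so reusing max_tokens
# here is presumably intended as a simple bound for smoke testing.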

# Print diagnostic information
@@ -86,6 +87,8 @@ def main(
print(f"model_config: {model_config}")
print(f"hf_config: {model_config.hf_config}")
if print_model:
# TODO: fix this for the latest vLLM; the lines below crash when building from
# source, see https://www.internalfb.com/phabricator/paste/view/P2028278010
model = llm.llm_engine.model_executor.driver_worker.model_runner.model
print(f"model: {model}")
