Commit a32a3fb

add nvfp4 recipe to llmcompressor script
Summary: Test Plan: Reviewers: Subscribers: Tasks: Tags:
1 parent 8da2452 commit a32a3fb
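
For reference, a minimal usage sketch of the new quant_type switch (not part of the commit; it assumes the script exposes run via fire.Fire(run), which the import fire line suggests but this diff does not show, and the import path is hypothetical):

# Sketch only: both recipes driven through the same entry point.
from quantize_hf_model_with_llm_compressor import run  # hypothetical import path

# existing path: fp8 per-channel weights + fp8 dynamic per-token activations
run(model_name='facebook/opt-125m', quant_type='fp8')

# new path from this commit: nvfp4 with ultrachat_200k calibration
run(model_name='facebook/opt-125m', quant_type='nvfp4')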

hf_torchao_vllm/quantize_hf_model_with_llm_compressor.py

Lines changed: 79 additions & 22 deletions
@@ -2,6 +2,7 @@

 import torch

+from datasets import load_dataset
 from transformers import AutoModelForCausalLM, AutoTokenizer

 from llmcompressor import oneshot
@@ -10,33 +11,89 @@

 import fire

-def run(model_name: str = 'facebook/opt-125m'):
+def run(
+    model_name: str = 'facebook/opt-125m',
+    quant_type: str = 'fp8',
+):
+    assert quant_type in ('fp8', 'nvfp4'), 'unsupported'
+
     # Load model.
     model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.bfloat16)
     print(model)
     tokenizer = AutoTokenizer.from_pretrained(model_name)

-    # Configure the quantization algorithm and scheme.
-    # In this case, we:
-    # * quantize the weights to fp8 with per channel via ptq
-    # * quantize the activations to fp8 with dynamic per token
-    recipe = QuantizationModifier(
-        targets="Linear",
-        scheme="FP8_DYNAMIC",
-        ignore=[
-            "lm_head",
-            # for Qwen MoE, but ok to just hardcode here for now
-            # https://github.com/vllm-project/llm-compressor/blob/33ef5f497a9801893764c6a2c880cb1f560067fa/examples/quantizing_moe/qwen_example.py#L10
-            "re:.*mlp.gate$",
-            "re:.*mlp.shared_expert_gate$",
-            # also skip attention and shared expert, to focus on MoE for now
-            "re:.*self_attn.*",
-            "re:.*shared_expert.*",
-        ],
-    )
+    if quant_type == 'fp8':
+        # Configure the quantization algorithm and scheme.
+        # In this case, we:
+        # * quantize the weights to fp8 with per channel via ptq
+        # * quantize the activations to fp8 with dynamic per token
+        recipe = QuantizationModifier(
+            targets="Linear",
+            scheme="FP8_DYNAMIC",
+            ignore=[
+                "lm_head",
+                # for Qwen MoE, but ok to just hardcode here for now
+                # https://github.com/vllm-project/llm-compressor/blob/33ef5f497a9801893764c6a2c880cb1f560067fa/examples/quantizing_moe/qwen_example.py#L10
+                "re:.*mlp.gate$",
+                "re:.*mlp.shared_expert_gate$",
+                # also skip attention and shared expert, to focus on MoE for now
+                "re:.*self_attn.*",
+                "re:.*shared_expert.*",
+            ],
+        )
+
+        # Apply quantization.
+        oneshot(model=model, recipe=recipe)
+
+    else:
+        assert quant_type == 'nvfp4', 'unsupported'
+        DATASET_ID = "HuggingFaceH4/ultrachat_200k"
+        DATASET_SPLIT = "train_sft"
+        NUM_CALIBRATION_SAMPLES = 20
+        MAX_SEQUENCE_LENGTH = 2048
+        ds = load_dataset(DATASET_ID, split=f"{DATASET_SPLIT}[:{NUM_CALIBRATION_SAMPLES}]")
+        ds = ds.shuffle(seed=42)
+
+        def preprocess(example):
+            chat_template = "{% for message in messages %}{{ message['content'] }}{% endfor %}"
+            return {
+                "text": tokenizer.apply_chat_template(
+                    example["messages"],
+                    tokenize=False,
+                    chat_template=chat_template,
+                )
+            }
+
+        ds = ds.map(preprocess)
+
+        # Tokenize inputs.
+        def tokenize(sample):
+            return tokenizer(
+                sample["text"],
+                padding=False,
+                max_length=MAX_SEQUENCE_LENGTH,
+                truncation=True,
+                add_special_tokens=False,
+            )
+
+
+        ds = ds.map(tokenize, remove_columns=ds.column_names)
+
+        # Configure the quantization algorithm and scheme.
+        # In this case, we:
+        # * quantize the weights to fp4 with per group 16 via ptq
+        # * calibrate a global_scale for activations, which will be used to
+        #   quantize activations to fp4 on the fly
+        recipe = QuantizationModifier(targets="Linear", scheme="NVFP4", ignore=["lm_head"])

-    # Apply quantization.
-    oneshot(model=model, recipe=recipe)
+        # Apply quantization.
+        oneshot(
+            model=model,
+            dataset=ds,
+            recipe=recipe,
+            max_seq_length=MAX_SEQUENCE_LENGTH,
+            num_calibration_samples=NUM_CALIBRATION_SAMPLES,
+        )

     # Confirm generations of the quantized model look sane.
     print("========== SAMPLE GENERATION ==============")
@@ -49,7 +106,7 @@ def run(model_name: str = 'facebook/opt-125m'):
     print("==========================================")

     # Save to disk in compressed-tensors format.
-    SAVE_DIR = "data/llmcompressor/" + "fp8-" + model_name.rstrip("/").split("/")[-1]
+    SAVE_DIR = "data/llmcompressor/" + quant_type + "-" + model_name.rstrip("/").split("/")[-1]
     model.save_pretrained(SAVE_DIR)
     tokenizer.save_pretrained(SAVE_DIR)
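
For a quick sanity check of the saved compressed-tensors checkpoint, a sketch along these lines could be used (an assumption, not part of this commit; it requires a vLLM build and GPU with compressed-tensors and NVFP4 support, and the path mirrors SAVE_DIR above for the default model):

# Sketch only: load the quantized checkpoint produced by this script in vLLM.
from vllm import LLM, SamplingParams

llm = LLM(model="data/llmcompressor/nvfp4-opt-125m")  # SAVE_DIR for the default args
outputs = llm.generate(["Hello, my name is"], SamplingParams(max_tokens=32))
print(outputs[0].outputs[0].text)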
