import torch

+ from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer

from llmcompressor import oneshot
from llmcompressor.modifiers.quantization import QuantizationModifier

import fire

- def run(model_name: str = 'facebook/opt-125m'):
+ def run(
+     model_name: str = 'facebook/opt-125m',
+     quant_type: str = 'fp8',
+ ):
+     assert quant_type in ('fp8', 'nvfp4'), 'unsupported'
+ 
    # Load model.
    model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.bfloat16)
    print(model)
    tokenizer = AutoTokenizer.from_pretrained(model_name)

-     # Configure the quantization algorithm and scheme.
-     # In this case, we:
-     # * quantize the weights to fp8 per-channel via PTQ
-     # * quantize the activations to fp8 with dynamic per-token scales
-     recipe = QuantizationModifier(
-         targets="Linear",
-         scheme="FP8_DYNAMIC",
-         ignore=[
-             "lm_head",
-             # for Qwen MoE, but ok to just hardcode here for now
-             # https://github.com/vllm-project/llm-compressor/blob/33ef5f497a9801893764c6a2c880cb1f560067fa/examples/quantizing_moe/qwen_example.py#L10
-             "re:.*mlp.gate$",
-             "re:.*mlp.shared_expert_gate$",
-             # also skip attention and shared expert, to focus on MoE for now
-             "re:.*self_attn.*",
-             "re:.*shared_expert.*",
-         ],
-     )
+     if quant_type == 'fp8':
+         # Configure the quantization algorithm and scheme.
+         # In this case, we:
+         # * quantize the weights to fp8 per-channel via PTQ
+         # * quantize the activations to fp8 with dynamic per-token scales
+         recipe = QuantizationModifier(
+             targets="Linear",
+             scheme="FP8_DYNAMIC",
+             ignore=[
+                 "lm_head",
+                 # for Qwen MoE, but ok to just hardcode here for now
+                 # https://github.com/vllm-project/llm-compressor/blob/33ef5f497a9801893764c6a2c880cb1f560067fa/examples/quantizing_moe/qwen_example.py#L10
+                 "re:.*mlp.gate$",
+                 "re:.*mlp.shared_expert_gate$",
+                 # also skip attention and shared expert, to focus on MoE for now
+                 "re:.*self_attn.*",
+                 "re:.*shared_expert.*",
+             ],
+         )
+ 
+         # Apply quantization (data-free: dynamic activation scales need no calibration).
+         oneshot(model=model, recipe=recipe)
+ 
+     else:
+         assert quant_type == 'nvfp4', 'unsupported'
+         # NVFP4 needs a small calibration set to fit its global scales.
+         DATASET_ID = "HuggingFaceH4/ultrachat_200k"
+         DATASET_SPLIT = "train_sft"
+         NUM_CALIBRATION_SAMPLES = 20
+         MAX_SEQUENCE_LENGTH = 2048
+         ds = load_dataset(DATASET_ID, split=f"{DATASET_SPLIT}[:{NUM_CALIBRATION_SAMPLES}]")
+         ds = ds.shuffle(seed=42)
+ 
+         # Render each chat as plain concatenated message text.
+         def preprocess(example):
+             chat_template = "{% for message in messages %}{{ message['content'] }}{% endfor %}"
+             return {
+                 "text": tokenizer.apply_chat_template(
+                     example["messages"],
+                     tokenize=False,
+                     chat_template=chat_template,
+                 )
+             }
+ 
+         ds = ds.map(preprocess)
+ 
+         # Tokenize inputs.
+         def tokenize(sample):
+             return tokenizer(
+                 sample["text"],
+                 padding=False,
+                 max_length=MAX_SEQUENCE_LENGTH,
+                 truncation=True,
+                 add_special_tokens=False,
+             )
+ 
+         ds = ds.map(tokenize, remove_columns=ds.column_names)
+ 
+         # Configure the quantization algorithm and scheme.
+         # In this case, we:
+         # * quantize the weights to fp4 with a group size of 16 via PTQ
+         # * calibrate a global_scale for activations, which will be used to
+         #   quantize activations to fp4 on the fly
+         recipe = QuantizationModifier(targets="Linear", scheme="NVFP4", ignore=["lm_head"])

-     # Apply quantization.
-     oneshot(model=model, recipe=recipe)
+         # Apply quantization.
+         oneshot(
+             model=model,
+             dataset=ds,
+             recipe=recipe,
+             max_seq_length=MAX_SEQUENCE_LENGTH,
+             num_calibration_samples=NUM_CALIBRATION_SAMPLES,
+         )

    # Confirm generations of the quantized model look sane.
    print("========== SAMPLE GENERATION ==============")
@@ -49,7 +106,7 @@ def run(model_name: str = 'facebook/opt-125m'):
    print("==========================================")

    # Save to disk in compressed-tensors format.
-     SAVE_DIR = "data/llmcompressor/" + "fp8-" + model_name.rstrip("/").split("/")[-1]
+     SAVE_DIR = "data/llmcompressor/" + quant_type + "-" + model_name.rstrip("/").split("/")[-1]
    model.save_pretrained(SAVE_DIR)
    tokenizer.save_pretrained(SAVE_DIR)

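For intuition about the two schemes, here is a rough numeric sketch of the scaling they describe. This is illustrative only, not llm-compressor's actual implementation (which lives in compressed-tensors); in particular, NVFP4 additionally carries the calibrated global scale mentioned in the comments above.

    import torch

    # FP8_DYNAMIC: one scale per token row, derived on the fly from that
    # row's max magnitude; fp8 e4m3 has max magnitude 448.
    x = torch.randn(4, 8)
    x_scale = x.abs().amax(dim=-1, keepdim=True) / torch.finfo(torch.float8_e4m3fn).max
    x_fp8 = (x / x_scale).to(torch.float8_e4m3fn)

    # NVFP4: one scale per contiguous group of 16 weights along the input
    # dimension; fp4 e2m1 has max magnitude 6.0.
    w = torch.randn(32, 64)
    w_scales = w.view(32, -1, 16).abs().amax(dim=-1, keepdim=True) / 6.0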
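The entry point falls outside this diff, but the fire import suggests the script ends with fire.Fire(run). If so, a hypothetical invocation (the filename quantize.py is assumed) would look like:

    python quantize.py --model_name facebook/opt-125m --quant_type nvfp4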
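The checkpoint written to SAVE_DIR should be loadable directly in vLLM by pointing at the directory, since vLLM understands the compressed-tensors format. A minimal sketch, assuming the default opt-125m model quantized with fp8 and a vLLM build/GPU that supports the scheme:

    from vllm import LLM, SamplingParams
    llm = LLM(model="data/llmcompressor/fp8-opt-125m")
    out = llm.generate(["Hello, my name is"], SamplingParams(max_tokens=20))
    print(out[0].outputs[0].text)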