
Commit c099ad5

add mxfp8 qat example. (#2316)
1 parent 298671b commit c099ad5

6 files changed: 526 additions & 0 deletions

Lines changed: 136 additions & 0 deletions
@@ -0,0 +1,136 @@
# Quantization Aware Training (QAT)

Quantization-Aware Training (QAT) is a technique designed to bridge the accuracy gap often observed with Post-Training Quantization (PTQ). Unlike PTQ, which applies quantization after model training, QAT simulates the effects of low-precision arithmetic during the training process itself. This allows the model to adapt its weights and activations to quantization constraints, significantly reducing accuracy degradation. As a result, QAT is particularly effective in preserving model performance even at extremely low precisions, such as MXFP8 or MXFP4, making it a critical approach for deploying efficient, high-performance models on resource-constrained hardware.
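
The core idea can be illustrated with a toy fake-quantization operator: the forward pass rounds values to a low-precision grid so training "sees" quantization error, while a straight-through estimator lets gradients flow so the weights remain trainable. The sketch below is illustrative only; it uses simple symmetric int8-style rounding, not the MXFP8 microscaling format used in this example.

```python
# Illustrative only: a toy fake-quantize op with a straight-through estimator (STE).
# The real MXFP8 quantizers inserted by prepare_qat() use block-wise microscaling,
# not the simple symmetric int8-style rounding shown here.
import torch


class FakeQuant(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, scale):
        # Quantize-dequantize so the forward pass "sees" quantization error.
        q = torch.clamp(torch.round(x / scale), -128, 127)
        return q * scale

    @staticmethod
    def backward(ctx, grad_output):
        # STE: pass gradients through unchanged so the weights can still be trained.
        return grad_output, None


x = torch.randn(4, 4, requires_grad=True)
scale = x.detach().abs().max() / 127
FakeQuant.apply(x, scale).sum().backward()
print(x.grad)  # all ones: gradients flow straight through the rounding op
```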

## Pre-Requisites

Install the requirements for the example:

```bash
pip install -r requirements.txt
```

## Getting Started

In QAT, a model quantized using `prepare_qat()` can be directly fine-tuned with the original training pipeline. During QAT, the scaling factors inside the quantizers are frozen and the model weights are fine-tuned.
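
The commands below drive `main.py`; condensed to its essentials, the prepare → fine-tune → export flow in that script looks roughly like this (the paths are placeholders and the fine-tuning step is left to your existing training pipeline):

```python
# Condensed sketch of the flow implemented in main.py.
# Paths are placeholders; fine-tuning itself is done with your usual pipeline.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from neural_compressor.torch.quantization.quantize import prepare_qat
from neural_compressor.torch.export.export_hf import export_hf2compressored_model

model = AutoModelForCausalLM.from_pretrained("./llama3.1-finetuned", torch_dtype=torch.bfloat16)
tokenizer = AutoTokenizer.from_pretrained("./llama3.1-finetuned")

# Insert fake-quantization modules in place (the default scheme is MXFP8).
prepare_qat(model)

# ... fine-tune `model` with your usual Trainer / training loop ...

# Export in a compressed-tensors layout that vLLM can load, then save the tokenizer.
export_hf2compressored_model(model, "./llama3.1-finetuned-qat", "MXFP8")
tokenizer.save_pretrained("./llama3.1-finetuned-qat")
```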

### Hugging Face QAT

#### QAT

##### Step 1:

Start by training or fine-tuning your model in its original precision (e.g., BF16). This establishes a strong baseline before introducing quantization.

```bash
accelerate launch --config-file accelerate_config/fsdp1.yaml \
    --fsdp_transformer_layer_cls_to_wrap LlamaDecoderLayer \
    main.py \
    --model_name_or_path meta-llama/Llama-3.1-8B \
    --model_max_length 4096 \
    --dataloader_drop_last True \
    --do_train True \
    --do_eval True \
    --output_dir ./llama3.1-finetuned \
    --dataset Daring-Anteater \
    --num_train_epochs 2.0 \
    --per_device_train_batch_size 4 \
    --per_device_eval_batch_size 4 \
    --gradient_accumulation_steps 1 \
    --eval_accumulation_steps 1 \
    --save_strategy steps \
    --save_steps 3000 \
    --eval_strategy steps \
    --eval_steps 3000 \
    --load_best_model_at_end True \
    --save_total_limit 2 \
    --learning_rate 1e-5 \
    --weight_decay 0.0 \
    --warmup_ratio 0.1 \
    --lr_scheduler_type linear \
    --logging_steps 1 \
    --report_to tensorboard
```

##### Step 2:

Quantize the trained model using `prepare_qat()` by setting the flags `--quant_scheme MXFP8 --do_train False`. This inserts fake-quantization modules into the model without starting training. Then save the model directly to obtain a post-training quantization (PTQ) model.

```bash
accelerate launch --config-file accelerate_config/fsdp1.yaml \
    --fsdp_transformer_layer_cls_to_wrap LlamaDecoderLayer \
    main.py \
    --model_name_or_path ./llama3.1-finetuned \
    --model_max_length 4096 \
    --dataloader_drop_last True \
    --do_train False \
    --do_eval False \
    --quant_scheme MXFP8 \
    --output_dir ./llama3.1-finetuned-ptq \
    --dataset Daring-Anteater \
    --num_train_epochs 2.0 \
    --per_device_train_batch_size 4 \
    --per_device_eval_batch_size 4 \
    --gradient_accumulation_steps 1 \
    --eval_accumulation_steps 1 \
    --save_strategy steps \
    --save_steps 3000 \
    --eval_strategy steps \
    --eval_steps 3000 \
    --load_best_model_at_end True \
    --save_total_limit 2 \
    --learning_rate 1e-5 \
    --weight_decay 0.0 \
    --warmup_ratio 0.1 \
    --lr_scheduler_type linear \
    --logging_steps 1 \
    --report_to tensorboard
```

##### Step 3:

Train/fine-tune the quantized model with a small learning rate, e.g., 1e-5 with the Adam optimizer, by setting `--quant_scheme MXFP8 --do_train True`:

```bash
accelerate launch --config-file accelerate_config/fsdp1.yaml \
    --fsdp_transformer_layer_cls_to_wrap LlamaDecoderLayer \
    main.py \
    --model_name_or_path ./llama3.1-finetuned \
    --model_max_length 4096 \
    --dataloader_drop_last True \
    --do_train True \
    --do_eval True \
    --quant_scheme MXFP8 \
    --output_dir ./llama3.1-finetuned-qat \
    --dataset Daring-Anteater \
    --max_steps 1000 \
    --per_device_train_batch_size 4 \
    --per_device_eval_batch_size 4 \
    --gradient_accumulation_steps 1 \
    --eval_accumulation_steps 1 \
    --save_strategy steps \
    --save_steps 3000 \
    --eval_strategy steps \
    --eval_steps 3000 \
    --load_best_model_at_end True \
    --save_total_limit 2 \
    --learning_rate 1e-5 \
    --weight_decay 0.0 \
    --warmup_ratio 0.03 \
    --lr_scheduler_type linear \
    --logging_steps 1 \
    --report_to tensorboard
```

#### Evaluation

Once QAT is complete, the saved quantized model can be deployed using vLLM for efficient inference. For example, to evaluate on GSM8K:

```bash
lm_eval \
    --model vllm \
    --model_args pretrained=./llama3.1-finetuned-qat,tensor_parallel_size=1,data_parallel_size=1,gpu_memory_utilization=0.3,max_model_len=32768,enforce_eager=True \
    --tasks gsm8k \
    --batch_size 8
```
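
The exported checkpoint can also be loaded directly with vLLM's offline Python API for a quick smoke test; a minimal sketch follows (the prompt and sampling settings are arbitrary examples):

```python
# Minimal sketch: load the exported QAT checkpoint with vLLM's offline API.
# The prompt and sampling settings are arbitrary examples.
from vllm import LLM, SamplingParams

llm = LLM(model="./llama3.1-finetuned-qat", gpu_memory_utilization=0.3, max_model_len=4096)
outputs = llm.generate(
    ["Question: A farmer has 12 rows of 7 apple trees. How many trees are there in total? Answer:"],
    SamplingParams(temperature=0.0, max_tokens=128),
)
print(outputs[0].outputs[0].text)
```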
Lines changed: 17 additions & 0 deletions
@@ -0,0 +1,17 @@
compute_environment: LOCAL_MACHINE
debug: false
distributed_type: MULTI_GPU
downcast_bf16: 'no'
enable_cpu_affinity: false
gpu_ids: all
machine_rank: 0
main_training_function: main
mixed_precision: bf16
num_machines: 1
num_processes: gpu
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false
Lines changed: 29 additions & 0 deletions
@@ -0,0 +1,29 @@
compute_environment: LOCAL_MACHINE
debug: false
distributed_type: FSDP
downcast_bf16: 'no'
enable_cpu_affinity: false
fsdp_config:
  fsdp_activation_checkpointing: true
  fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
  fsdp_backward_prefetch: BACKWARD_PRE
  fsdp_cpu_ram_efficient_loading: true
  fsdp_forward_prefetch: false
  fsdp_offload_params: false
  fsdp_reshard_after_forward: FULL_SHARD
  fsdp_state_dict_type: FULL_STATE_DICT
  fsdp_sync_module_states: true
  fsdp_transformer_layer_cls_to_wrap: LlamaDecoderLayer
  fsdp_use_orig_params: true
  fsdp_version: 1
machine_rank: 0
main_training_function: main
mixed_precision: bf16
num_machines: 1
num_processes: gpu
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false
Lines changed: 187 additions & 0 deletions
@@ -0,0 +1,187 @@
import logging
import os
import sys
from dataclasses import dataclass, field
from warnings import warn

import torch
import transformers
from transformers.trainer_utils import get_last_checkpoint
from transformers import (
    AutoConfig,
    AutoModelForCausalLM,
    AutoTokenizer,
    HfArgumentParser,
    Trainer,
    default_data_collator,
    set_seed,
    TrainerCallback,
)

from utils import (
    get_metrics_with_perplexity,
    make_supervised_data_module,
)

logger = logging.getLogger(__name__)


@dataclass
class ModelArguments:
    model_name_or_path: str = field(default="meta-llama/Llama-3.1-8B")


@dataclass
class TrainingArguments(transformers.TrainingArguments):
    cache_dir: str | None = field(default=None)
    model_max_length: int = field(
        default=2048,
        metadata={
            "help": (
                "Maximum sequence length. Sequences will be right padded (and possibly truncated)."
            )
        },
    )
    dataloader_drop_last: bool = field(default=True)
    bf16: bool = field(default=True)


@dataclass
class DataArguments:
    dataset: str = field(
        default="Daring-Anteater",
        metadata={"help": "Specify the dataset.", "choices": ["Daring-Anteater"]},
    )
    train_size: int = field(
        default=0,
        metadata={"help": "Number of training samples to use. If `0`, use default training size."},
    )
    eval_size: int = field(
        default=0,
        metadata={
            "help": "Number of evaluation samples to use. If `0`, use default evaluation size."
        },
    )


@dataclass
class QuantizationArguments:
    quant_scheme: str | None = field(
        default=None,
        metadata={
            "help": (
                "Specify the quantization format for PTQ/QAT. If specified, PTQ/QAT will be enabled"
                " with the specified quantization format."
            ),
            "choices": ["MXFP8"],
        },
    )


def train():
    parser = HfArgumentParser(
        (ModelArguments, TrainingArguments, DataArguments, QuantizationArguments)
    )

    model_args, training_args, data_args, quant_args = parser.parse_args_into_dataclasses()

    # Setup logging
    logging.basicConfig(
        level=logging.INFO,
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        handlers=[logging.StreamHandler(sys.stdout)],
    )

    # Log on each process the small summary:
    logger.warning(
        f"Process rank: {training_args.local_process_index}, device: {training_args.device}, n_gpu: {training_args.n_gpu}, "
        + f"distributed training: {training_args.parallel_mode.value == 'distributed'}, 16-bits training: {training_args.fp16}"
    )
    # Set seed before initializing model.
    set_seed(training_args.seed)

    logger.info(f"arguments: {model_args}, {training_args}, {data_args}, {quant_args}")

    # Detecting last checkpoint.
    last_checkpoint = None
    if os.path.isdir(training_args.output_dir) and training_args.do_train:
        last_checkpoint = get_last_checkpoint(training_args.output_dir)
        logger.info(f"Last checkpoint detected: {last_checkpoint}")

    model = AutoModelForCausalLM.from_pretrained(
        model_args.model_name_or_path,
        cache_dir=training_args.cache_dir,
        torch_dtype=torch.bfloat16,
    )
    model.generation_config.do_sample = True
    tokenizer = AutoTokenizer.from_pretrained(
        model_args.model_name_or_path, model_max_length=training_args.model_max_length
    )
    tokenizer.pad_token_id = tokenizer.eos_token_id

    # We set model.config.use_cache to False for training when gradient_checkpointing=False.
    # Currently useful for FSDP2 to allow for setting activation_checkpointing=True in the config file.
    model.config.use_cache = False

    # Prepare the model for quantization.
    if quant_args.quant_scheme is not None:
        from neural_compressor.torch.quantization.quantize import prepare_qat

        # In-place; the default scheme is MXFP8.
        prepare_qat(model)

        logger.info("Finish model preparation for QAT.")

    logger.info("Loading dataset......")

    # Reuse the dataset function. TODO: preprocess a new dataset.
    data_module = make_supervised_data_module(
        dataset=data_args.dataset,
        tokenizer=tokenizer,
        train_size=data_args.train_size,
        eval_size=data_args.eval_size,
    )

    # Ensure calibration size doesn't exceed evaluation dataset size.
    eval_dataset_size = len(data_module["eval_dataset"])

    # Training
    checkpoint = None
    if training_args.resume_from_checkpoint is not None:
        checkpoint = training_args.resume_from_checkpoint
    elif last_checkpoint is not None:
        checkpoint = last_checkpoint

    # Torch >= 2.4 throws an error if `use_reentrant` is not set explicitly.
    if training_args.gradient_checkpointing and training_args.gradient_checkpointing_kwargs is None:
        training_args.gradient_checkpointing_kwargs = {"use_reentrant": True}

    trainer = Trainer(
        model=model,
        processing_class=tokenizer,
        args=training_args,
        **data_module,
    )

    if training_args.do_train:
        logger.info("Starting Train...")
        trainer.train(resume_from_checkpoint=checkpoint)
        logger.info("Training completed.")

    if training_args.do_eval:
        logger.info("Starting Evaluation...")
        metrics = trainer.evaluate()
        metrics = get_metrics_with_perplexity(metrics)
        logger.info(f"Evaluation results: \n{metrics}")

    if training_args.do_train and quant_args.quant_scheme is None:
        logger.info("Saving the model...")
        trainer.save_model(training_args.output_dir)
    elif quant_args.quant_scheme is not None:
        from neural_compressor.torch.export.export_hf import export_hf2compressored_model

        # Export the quantized model for vLLM inference using llm-compressor and compressed-tensors.
        export_hf2compressored_model(model, training_args.output_dir, quant_args.quant_scheme)
        if tokenizer is not None:
            tokenizer.save_pretrained(training_args.output_dir)


if __name__ == "__main__":
    train()
Lines changed: 9 additions & 0 deletions
@@ -0,0 +1,9 @@
auto-round==0.8.0
neural-compressor-pt==3.6
transformers==4.52.4
datasets
sentencepiece>=0.2.0
tensorboardX
peft
accelerate>=0.12.0
lm-eval==0.4.9.1
