
Commit bdc6697

LHXuuu and chenxi-hh authored
[Quantization] Support compressed tensors w8a8 static and w8a8 dynamic weight (#4036)
### What this PR does / why we need it?

When quantized weights are generated with the LLM Compressor quantization tool from the vLLM community, the vLLM Ascend engine needs to be adapted to support the compressed-tensors quantization format. This PR:

1. Adds AscendCompressedTensorsConfig to replace CompressedTensorsConfig in vLLM.
2. Supports CompressedTensorsW8A8 static weights.
   - weight: per-channel, int8, symmetric; activation: per-tensor, int8, symmetric.
3. Supports CompressedTensorsW8A8Dynamic weights.
   - weight: per-channel, int8, symmetric; activation: per-token, int8, symmetric, dynamic.
4. Modifies override_quantization_method in AscendQuantConfig.

Co-authored-by: taoqun110 <taoqun@huawei.com>
Co-authored-by: chenxi-hh <chen464822955@163.com>

- vLLM version: v0.11.2

---------

Signed-off-by: LHXuuu <scut_xlh@163.com>
Signed-off-by: chenxi-hh <chen464822955@163.com>
Signed-off-by: chenxi-hh <32731611+chenxi-hh@users.noreply.github.com>
Co-authored-by: chenxi-hh <chen464822955@163.com>
Co-authored-by: chenxi-hh <32731611+chenxi-hh@users.noreply.github.com>
1 parent ab37a7d commit bdc6697

18 files changed, +707 -32 lines changed

.github/workflows/_e2e_test.yaml

Lines changed: 1 addition & 0 deletions
@@ -179,6 +179,7 @@ jobs:
           VLLM_USE_MODELSCOPE: True
         if: ${{ inputs.type == 'full' }}
         run: |
+          pytest -sv tests/e2e/multicard/test_quantization.py
           pytest -sv tests/e2e/multicard/test_aclgraph_capture_replay.py
           pytest -sv tests/e2e/multicard/test_torchair_graph_mode.py
           pytest -sv tests/e2e/multicard/test_full_graph_mode.py

docs/source/user_guide/feature_guide/index.md

Lines changed: 1 addition & 0 deletions
@@ -7,6 +7,7 @@ This section provides a detailed usage guide of vLLM Ascend features.
 :maxdepth: 1
 graph_mode
 quantization
+quantization-llm-compressor
 sleep_mode
 structured_output
 lora
Lines changed: 65 additions & 0 deletions
@@ -0,0 +1,65 @@
# llm-compressor Quantization Guide

Model quantization is a technique that reduces the size and computational requirements of a model by lowering the data precision of its weights and activation values, thereby saving memory and improving inference speed.

## Supported llm-compressor Quantization Types

CompressedTensorsW8A8 (static) weights are supported:

- weight: per-channel, int8, symmetric; activation: per-tensor, int8, symmetric.

CompressedTensorsW8A8Dynamic weights are supported:

- weight: per-channel, int8, symmetric; activation: per-token, int8, symmetric, dynamic.
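To make the two schemes concrete, here is a small illustrative sketch (plain PyTorch with assumed tensor shapes; this is not the vLLM Ascend kernel implementation) of how the symmetric int8 scales differ: per-channel for weights, per-tensor for static activations, and per-token for dynamic activations.

```python
# Illustrative only: symmetric int8 scale/quantization math for the schemes above.
import torch

w = torch.randn(4096, 11008)   # weight, [out_features, in_features] (assumed shape)
x = torch.randn(8, 11008)      # activations, [tokens, in_features] (assumed shape)

# Weights (both schemes): per-channel, symmetric -> one scale per output channel.
w_scale = w.abs().amax(dim=1, keepdim=True) / 127.0
w_q = torch.clamp(torch.round(w / w_scale), -128, 127).to(torch.int8)

# W8A8 static: a single per-tensor activation scale, calibrated offline and reused at runtime.
x_scale_static = x.abs().amax() / 127.0

# W8A8 dynamic: a per-token activation scale, recomputed on the fly for every input row.
x_scale_dynamic = x.abs().amax(dim=1, keepdim=True) / 127.0
x_q = torch.clamp(torch.round(x / x_scale_dynamic), -128, 127).to(torch.int8)
```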
## Install llm-compressor

To quantize a model, install [llm-compressor](https://github.com/vllm-project/llm-compressor/blob/main/README.md), a unified library for creating compressed models for faster inference with vLLM:

```bash
pip install llmcompressor
```

### Generate the W8A8 weights

```bash
cd examples/quantization/llm-compressor

python3 w8a8_int8_dynamic.py
```

For more details, see the [official examples](https://github.com/vllm-project/llm-compressor/tree/main/examples).
## Run the model

Now you can run the quantized model with vLLM Ascend. Examples for online and offline inference are provided below:

### Offline inference

```python
import torch

from vllm import LLM, SamplingParams

prompts = [
    "Hello, my name is",
    "The future of AI is",
]
sampling_params = SamplingParams(temperature=0.6, top_p=0.95, top_k=40)

llm = LLM(model="{quantized_model_save_path}",
          max_model_len=2048,
          trust_remote_code=True)

outputs = llm.generate(prompts, sampling_params)
for output in outputs:
    prompt = output.prompt
    generated_text = output.outputs[0].text
    print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
```

### Online inference

Start the quantized model with vLLM Ascend as usual; no modifications to the startup command are required.
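For reference, a minimal online-serving sketch using the standard vLLM OpenAI-compatible server (the model path is a placeholder for your quantized checkpoint; the port and payload are the vLLM defaults):

```bash
# Serve the quantized checkpoint (placeholder path) with the default OpenAI-compatible API.
vllm serve {quantized_model_save_path} --max-model-len 2048 --trust-remote-code

# In another shell, send a completion request to the default endpoint.
curl http://localhost:8000/v1/completions \
    -H "Content-Type: application/json" \
    -d '{"model": "{quantized_model_save_path}", "prompt": "Hello, my name is", "max_tokens": 64}'
```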
Lines changed: 160 additions & 0 deletions
@@ -0,0 +1,160 @@
import os
import torch

from datasets import load_dataset
from transformers import AutoModelForCausalLM, Qwen2VLForConditionalGeneration, Qwen2_5_VLForConditionalGeneration, \
    AutoTokenizer, AutoProcessor, AutoConfig, AutoImageProcessor

from llmcompressor import oneshot
from llmcompressor.modifiers.awq import AWQModifier
from llmcompressor.modifiers.quantization import GPTQModifier, QuantizationModifier
from compressed_tensors.quantization import QuantizationArgs, QuantizationScheme, QuantizationType, QuantizationStrategy

W8A8_W_cha_A_ten_static_symmetric = {
    "group_0": QuantizationScheme(
        targets=["Linear"],
        weights=QuantizationArgs(
            num_bits=8,
            type=QuantizationType.INT,
            strategy=QuantizationStrategy.CHANNEL,
            symmetric=True,
            dynamic=False
        ),
        input_activations=QuantizationArgs(
            num_bits=8,
            type=QuantizationType.INT,
            strategy=QuantizationStrategy.TENSOR,
            symmetric=True,
            dynamic=False
        ),
    ),
}

# supported modifiers
MODIFIER_DICT = {
    "PTQ": QuantizationModifier,
    "AWQ": AWQModifier,
    "GPTQ": GPTQModifier,
}

# supported schemes
SCHEMES_DICT = {
    "W8A8_W_cha_A_ten_static_symmetric": W8A8_W_cha_A_ten_static_symmetric,
}

MODEL_DICT = {
    "qwen3": AutoModelForCausalLM,
}

TOKENIZER_DICT = {
    "qwen3": AutoTokenizer,
}


def load_environment_variables():
    env_vars = {
        'model_path': "Qwen/Qwen3-32B",
        'export_path': "/llm-compressor/export/GPTQ/W8A8_W_cha_A_ten_static_symmetric",
        'modifier': "GPTQ",
        'schemes': "W8A8_W_cha_A_ten_static_symmetric",
        'calib_prompt_path': "HuggingFaceH4/ultrachat_200k"
    }

    # verify export model path
    if env_vars['export_path'] is None:
        env_vars['export_path'] = env_vars['model_path'].rstrip("/") + "-" + env_vars['modifier']
        if env_vars['schemes'] is not None:
            env_vars['export_path'] += "-" + env_vars['schemes']
    os.makedirs(env_vars['export_path'], exist_ok=True)

    return env_vars


def load_calibration_text_dataset(calib_prompt_path, tokenizer):
    # Load dataset
    for f in os.listdir(calib_prompt_path):
        print(f)
    if any(f.lower().endswith('.jsonl') for f in os.listdir(calib_prompt_path)):
        ds = load_dataset('json', data_dir=calib_prompt_path, split='validation')
    elif any(f.lower().endswith('.parquet') for f in os.listdir(calib_prompt_path)):
        ds = load_dataset("parquet", data_dir=calib_prompt_path, split="train[:512]")
    else:
        raise ValueError("Unsupported calibration file format: {}".format(
            calib_prompt_path.split('.')[-1]))

    # Preprocess dataset
    def preprocess(example):
        if tokenizer.chat_template is not None:
            return {"text": tokenizer.apply_chat_template(
                example["messages"], tokenize=False)}
        else:
            return {"text": example["messages"]}

    # Tokenize inputs
    def tokenize(sample):
        return tokenizer(
            sample["text"],
            add_special_tokens=False,
        )

    ds = ds.map(preprocess)
    ds = ds.map(tokenize, remove_columns=ds.column_names)
    return ds


# Define a oneshot data collator for multimodal inputs.
def data_collator(batch):
    assert len(batch) == 1
    return {
        key: torch.tensor(value, dtype=torch.bfloat16 if key == "pixel_values" else torch.long)
        for key, value in batch[0].items()
    }


def quantize_model(model, env_vars, dataset_dict=None):
    # since the MoE gate layers are sensitive to quantization, we add them to the ignore
    # list so they remain at full precision
    ignore = ["lm_head", "re:.*mlp.down_proj"]

    # define a llmcompressor recipe
    recipe = [
        MODIFIER_DICT[env_vars['modifier']](
            config_groups=SCHEMES_DICT[env_vars['schemes']],
            ignore=ignore,
        ),
    ]

    # quantize the model
    oneshot(
        model=model,
        dataset=dataset_dict,
        recipe=recipe,
        trust_remote_code_model=True,
    )


def save_quantized_model(model, tokenizer, save_path, save_compressed=False):
    model.save_pretrained(save_path, save_compressed=save_compressed)
    tokenizer.save_pretrained(save_path)


if __name__ == '__main__':
    # get environment variables
    env_vars = load_environment_variables()

    # support model type list
    config = AutoConfig.from_pretrained(env_vars['model_path'], trust_remote_code=True)
    model_type = config.model_type

    model = MODEL_DICT[model_type].from_pretrained(
        env_vars['model_path'], torch_dtype="auto", trust_remote_code=True
    )
    tokenizer = TOKENIZER_DICT[model_type].from_pretrained(env_vars['model_path'], trust_remote_code=True)

    ds = load_calibration_text_dataset(env_vars["calib_prompt_path"], tokenizer)

    # Quantize the model
    quantize_model(model, env_vars, ds)

    # save the quantized model
    save_quantized_model(model, tokenizer, env_vars['export_path'], True)
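Note that `load_calibration_text_dataset` in this script scans a local directory (`os.listdir`) for `.jsonl`/`.parquet` files, so the `calib_prompt_path` above must exist on disk. One way to stage the calibration dataset locally, assuming the Hugging Face CLI is available (this staging step is not part of the change above), might be:

```bash
# Hypothetical staging step: fetch the calibration dataset so that
# os.listdir("HuggingFaceH4/ultrachat_200k") finds its parquet files locally.
pip install -U "huggingface_hub[cli]"
huggingface-cli download HuggingFaceH4/ultrachat_200k \
    --repo-type dataset \
    --local-dir HuggingFaceH4/ultrachat_200k
```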
Lines changed: 83 additions & 0 deletions
@@ -0,0 +1,83 @@
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer

from llmcompressor import oneshot
from llmcompressor.modifiers.quantization import GPTQModifier
from llmcompressor.modifiers.smoothquant import SmoothQuantModifier
from llmcompressor.utils import dispatch_for_generation

# Select model and load it.
MODEL_ID = "meta-llama/Meta-Llama-3-8B-Instruct"
model = AutoModelForCausalLM.from_pretrained(MODEL_ID, torch_dtype="auto")
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)

# Select calibration dataset.
DATASET_ID = "HuggingFaceH4/ultrachat_200k"
DATASET_SPLIT = "train_sft"

# Select number of samples. 512 samples is a good place to start.
# Increasing the number of samples can improve accuracy.
NUM_CALIBRATION_SAMPLES = 512
MAX_SEQUENCE_LENGTH = 2048

# Load dataset and preprocess.
ds = load_dataset(DATASET_ID, split=f"{DATASET_SPLIT}[:{NUM_CALIBRATION_SAMPLES}]")
ds = ds.shuffle(seed=42)


def preprocess(example):
    return {
        "text": tokenizer.apply_chat_template(
            example["messages"],
            tokenize=False,
        )
    }


ds = ds.map(preprocess)


# Tokenize inputs.
def tokenize(sample):
    return tokenizer(
        sample["text"],
        padding=False,
        max_length=MAX_SEQUENCE_LENGTH,
        truncation=True,
        add_special_tokens=False,
    )


ds = ds.map(tokenize, remove_columns=ds.column_names)

# Configure algorithms. In this case, we:
# * apply SmoothQuant to make the activations easier to quantize
# * quantize the weights to int8 with GPTQ (static per channel)
# * quantize the activations to int8 (dynamic per token)
recipe = [
    SmoothQuantModifier(smoothing_strength=0.8),
    GPTQModifier(targets="Linear", scheme="W8A8", ignore=["lm_head"]),
]

# Apply algorithms and save to output_dir
oneshot(
    model=model,
    dataset=ds,
    recipe=recipe,
    max_seq_length=MAX_SEQUENCE_LENGTH,
    num_calibration_samples=NUM_CALIBRATION_SAMPLES,
)

# Confirm generations of the quantized model look sane.
print("\n\n")
print("========== SAMPLE GENERATION ==============")
dispatch_for_generation(model)
input_ids = tokenizer("Hello my name is", return_tensors="pt").input_ids.to("npu")
output = model.generate(input_ids, max_new_tokens=100)
print(tokenizer.decode(output[0]))
print("==========================================\n\n")

# Save to disk compressed.
SAVE_DIR = MODEL_ID.rstrip("/").split("/")[-1] + "-W8A8-Dynamic-Per-Token"
model.save_pretrained(SAVE_DIR, save_compressed=True)
tokenizer.save_pretrained(SAVE_DIR)
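As a quick sanity check after this script finishes, one can inspect the quantization metadata that `save_pretrained(..., save_compressed=True)` writes into the exported `config.json`. A minimal sketch follows; the directory name comes from `SAVE_DIR` above, and the exact fields depend on the installed compressed-tensors version:

```python
# Sketch: print the compressed-tensors quantization section of the exported config.
import json

with open("Meta-Llama-3-8B-Instruct-W8A8-Dynamic-Per-Token/config.json") as f:
    cfg = json.load(f)

# Hugging Face configs store quantization metadata under "quantization_config";
# an empty result would mean the checkpoint was not saved in compressed form.
print(json.dumps(cfg.get("quantization_config", {}), indent=2))
```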

mypy.ini

Lines changed: 10 additions & 1 deletion
@@ -15,6 +15,15 @@ ignore_missing_imports = True
 [mypy-lm_eval.*]
 ignore_missing_imports = True

+[mypy-compressed_tensors.*]
+ignore_missing_imports = True
+
+[mypy-datasets.*]
+ignore_missing_imports = True
+
+[mypy-llmcompressor.*]
+ignore_missing_imports = True
+
 [mypy-msprobe.*]
 ignore_missing_imports = True
-allow_untyped_imports = True
+allow_untyped_imports = True

pyproject.toml

Lines changed: 1 addition & 0 deletions
@@ -23,6 +23,7 @@ requires = [
     "quart",
     "numba",
     "opencv-python-headless<=4.11.0.86", # Required to avoid numpy version conflict with vllm
+    "compressed_tensors>=0.11.0"
 ]
 build-backend = "setuptools.build_meta"

requirements.txt

Lines changed: 1 addition & 0 deletions
@@ -16,6 +16,7 @@ torchvision
 wheel
 pandas-stubs
 opencv-python-headless<=4.11.0.86 # Required to avoid numpy version conflict with vllm
+compressed_tensors>=0.11.0

 # requirements for disaggregated prefill
 msgpack
