Reproduction:

```
python test_quant_ao.py --model_id meta-llama/Llama-3.1-8B-Instruct --quant_type int4wo
```
```python
import argparse
import logging

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TorchAoConfig

logging.basicConfig(level=logging.INFO)


def get_args():
    parser = argparse.ArgumentParser(description="Quantize a model for text generation.")
    parser.add_argument("--model_id", type=str, required=True, help="The model ID to load from Hugging Face.")
    parser.add_argument(
        "--quant_type",
        type=str,
        choices=["int8", "int8wo", "w8a8", "nf4", "fp4", "int4", "int4wo"],
        default="int8",
        help="Quantization type.",
    )
    parser.add_argument("--device", type=str, default="cpu", choices=["cpu", "xpu", "cuda"], help="Device to run the model on.")
    return parser.parse_args()


def quantize_model(args):
    logging.info(f"Loading model: {args.model_id}")
    quantization_config = None
    model_kwargs = {}

    from torchao.quantization import (
        Int4WeightOnlyConfig,
        Int8DynamicActivationInt8WeightConfig,
        Int8WeightOnlyConfig,
    )
    from torchao.dtypes import Int4CPULayout

    if args.quant_type == "int4wo":
        # int4 weight-only with the CPU layout; this is the path that fails below.
        quantization_config = TorchAoConfig(quant_type=Int4WeightOnlyConfig(group_size=128, layout=Int4CPULayout()))
    elif args.quant_type == "int8wo":
        # int8 weight-only.
        quantization_config = TorchAoConfig(quant_type=Int8WeightOnlyConfig())
    elif args.quant_type == "w8a8":
        # int8 weights with int8 dynamic activation quantization.
        quantization_config = TorchAoConfig(quant_type=Int8DynamicActivationInt8WeightConfig())

    if quantization_config is not None:
        model_kwargs["quantization_config"] = quantization_config

    model_kwargs["device_map"] = "cpu" if args.device == "cpu" else "auto"
    model_kwargs["dtype"] = "auto"
    model = AutoModelForCausalLM.from_pretrained(args.model_id, **model_kwargs)
    logging.info("Model loaded and quantized successfully.")
    return model


if __name__ == "__main__":
    args = get_args()
    model = quantize_model(args)
    tokenizer = AutoTokenizer.from_pretrained(args.model_id)
    output_model_name = args.model_id + "_" + args.quant_type
    output_model_name = output_model_name.replace("/", "_")
    model.save_pretrained(output_model_name, safe_serialization=False)
    tokenizer.save_pretrained(output_model_name)
```

Output:
File "/opt/venv/lib/python3.12/site-packages/torch/utils/_contextlib.py", line 124, in decorate_context
return func(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^ File "/home/jiqing/transformers/src/transformers/modeling_utils.py", line 670, in _load_state_dict_into_meta_model
hf_quantizer.create_quantized_param(model, param, param_name, param_device)
File "/home/jiqing/transformers/src/transformers/quantizers/quantizer_torchao.py", line 351, in create_quantized_param
quantize_(module, self.quantization_config.get_apply_tensor_subclass())
File "/opt/venv/lib/python3.12/site-packages/torchao/quantization/quant_api.py", line 497, in quantize_
_replace_with_custom_fn_if_matches_filter(
File "/opt/venv/lib/python3.12/site-packages/torchao/quantization/quant_api.py", line 207, in _replace_with_custom_fn_if_matches_filt
er
model = replacement_fn(model, *extra_args)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/venv/lib/python3.12/site-packages/torchao/quantization/quant_api.py", line 1265, in _int4_weight_only_transform
new_weight = _int4_weight_only_quantize_tensor(module.weight, config)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/venv/lib/python3.12/site-packages/torchao/quantization/quant_api.py", line 1160, in _int4_weight_only_quantize_tensor
new_weight = Int4Tensor.from_hp(
^^^^^^^^^^^^^^^^^^^
File "/opt/venv/lib/python3.12/site-packages/torchao/quantization/quantize_/workflows/int4/int4_tensor.py", line 101, in from_hp
raise ImportError("Requires fbgemm-gpu-genai >= 1.2.0")
ImportError: Requires fbgemm-gpu-genai >= 1.2.0
Only switching to version=1 works. Is there any plan to enable CPU/XPU support in version=2? I assume version=1 will be deprecated in the future.
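For reference, a minimal sketch of the workaround mentioned above, assuming the `version` field on torchao's `Int4WeightOnlyConfig` is what selects between the legacy layout-based path (version=1) and the newer `Int4Tensor` path (version=2) that depends on fbgemm-gpu-genai:

```python
# Hedged sketch of the CPU workaround: pin int4 weight-only quantization to
# version=1. Assumes Int4WeightOnlyConfig accepts a `version` argument; with
# the default version=2, loading on CPU raises the ImportError shown above.
from transformers import TorchAoConfig
from torchao.quantization import Int4WeightOnlyConfig
from torchao.dtypes import Int4CPULayout

quantization_config = TorchAoConfig(
    quant_type=Int4WeightOnlyConfig(
        group_size=128,
        layout=Int4CPULayout(),
        version=1,
    )
)
```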