This repository was archived by the owner on Oct 25, 2024. It is now read-only.

Commit 4b50461

Update woq code for intel GPU (#1404)
1 parent: 1e00f29

4 files changed: +53, -40 lines


examples/huggingface/pytorch/text-generation/quantization/run_generation_gpu_woq.py

Lines changed: 30 additions & 16 deletions
@@ -50,37 +50,50 @@
 parser.add_argument("--tasks", nargs='+', default=["lambada_openai"], type=str, \
                     help="tasks list for accuracy validation")
 # ============WeightOnlyQuant configs===============
+parser.add_argument("--bits", type=int, default=4, choices=[4])
 parser.add_argument("--woq", action="store_true")
 parser.add_argument("--woq_algo", default="RTN", choices=['RTN', 'GPTQ'],
                     help="Weight-only parameter.")
-parser.add_argument("--woq_dtype", type=str, default="int4_fullrange",
+parser.add_argument("--weight_dtype", type=str, default="int4_fullrange",
                     choices=["int4_fullrange"])
-parser.add_argument("--woq_group_size", type=int, default=32)
-parser.add_argument("--woq_scheme", default="sym")
+parser.add_argument("--group_size", type=int, default=32)
+parser.add_argument("--scheme", default="sym")
 parser.add_argument("--woq_enable_mse_search", action="store_true")
 parser.add_argument("--device", default="xpu")
 parser.add_argument("--compute_dtype", default="fp16")
+# ============GPTQ configs==============
 parser.add_argument(
-    "--gptq_percdamp",
+    "--desc_act",
+    action="store_true",
+    help="Whether to apply the activation order GPTQ heuristic.",
+)
+parser.add_argument(
+    "--damp_percent",
     type=float,
     default=0.01,
     help="Percent of the average Hessian diagonal to use for dampening.",
 )
 parser.add_argument(
-    "--gptq_block_size",
+    "--blocksize",
     type=int,
     default=128,
     help="Block size. sub weight matrix size to run GPTQ.",
 )
 parser.add_argument(
-    "--gptq_nsamples", type=int, default=128, help="Number of calibration data samples."
+    "--nsamples", type=int, default=128, help="Number of calibration data samples."
 )
 parser.add_argument(
     "--max_input_length",
     type=int,
     default=2048,
     help="Calibration dataset sequence max length, this should align with your model config",
 )
+parser.add_argument(
+    "--static_groups",
+    action="store_true",
+    help="Use determined group to do quantization",
+)
+parser.add_argument("--calib_iters", default=100, type=int, help="Calibration iters.")
 # ============BitsAndBytes configs==============
 parser.add_argument("--bitsandbytes", action="store_true")
 parser.add_argument("--load_in_4bit", type=bool, default=False)
@@ -118,22 +131,22 @@
         dataset=args.dataset,
         bits=args.bits,
         desc_act=args.desc_act,
-        damp_percent=args.gptq_percdamp,
-        sym=True if args.woq_scheme == "sym" else False,
-        blocksize=args.gptq_block_size,
-        nsamples=args.gptq_nsamples,
+        damp_percent=args.damp_percent,
+        sym=True if args.scheme == "sym" else False,
+        blocksize=args.blocksize,
+        nsamples=args.nsamples,
         static_groups=args.static_groups,
-        group_size=args.woq_group_size,
+        group_size=args.group_size,
         max_input_length=args.max_input_length,
         compute_dtype=args.compute_dtype,
         scale_dtype=args.compute_dtype,
-        weight_dtype=args.woq_dtype,
+        weight_dtype=args.weight_dtype,
         calib_iters=args.calib_iters,
     )
 else:
     quantization_config = RtnConfig(
-        compute_dtype=args.compute_dtype, weight_dtype=args.woq_dtype,
-        group_size=args.woq_group_size, scale_dtype=args.compute_dtype
+        compute_dtype=args.compute_dtype, weight_dtype=args.weight_dtype,
+        group_size=args.group_size, scale_dtype=args.compute_dtype
     ) #default is A16W4G16

 # get model
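For orientation, a minimal end-to-end sketch of where a config like the RtnConfig above ends up: the ITREX AutoModelForCausalLM wrapper quantizes the model at load time. The import path and the placeholder checkpoint name are assumptions for illustration, not taken from this diff.

from intel_extension_for_transformers.transformers import AutoModelForCausalLM, RtnConfig

woq_config = RtnConfig(
    compute_dtype="fp16",
    weight_dtype="int4_fullrange",
    group_size=32,
    scale_dtype="fp16",
)  # mirrors the script's default path, i.e. A16W4G16

# "facebook/opt-125m" is just a small placeholder checkpoint.
model = AutoModelForCausalLM.from_pretrained(
    "facebook/opt-125m",
    quantization_config=woq_config,
    device_map="xpu",
    trust_remote_code=True,
)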
@@ -260,16 +273,17 @@
         args.model, trust_remote_code=args.trust_remote_code, device_map=args.device, torch_dtype=torch_dtype) \
             if user_model is None else user_model
     if quantization_config is None:
-        quantization_config = WeightOnlyQuantConfig.from_pretrained(args.model)
+        quantization_config = user_model.quantization_config if hasattr(user_model, "quantization_config") else {}
     if not args.disable_optimize_transformers:
         print("Optimize with IPEX...")
         user_model = ipex.optimize_transformers(
             user_model.eval(), device=args.device, inplace=True, quantization_config=quantization_config, dtype=torch_dtype)
     else:
         print("Disabled optimization with IPEX...")
+
     results = evaluate(
         model="hf-causal",
-        model_args='pretrained='+args.model+',tokenizer=' + args.model + \
+        model_args='pretrained=' + "facebook/opt-125m" +',tokenizer=' + args.model + \
             ',dtype=float32,trust_remote_code=' + str(args.trust_remote_code),
         user_model=user_model,
         batch_size=args.batch_size,

intel_extension_for_transformers/transformers/llm/quantization/utils.py

Lines changed: 9 additions & 0 deletions
@@ -514,6 +514,13 @@ def default_calib_func(model):
     if config.quant_method.value not in ["awq"]:
         calib_func = None

+    orig_dtype = torch.float32
+    for param in model.parameters():
+        orig_dtype = param.dtype
+        if orig_dtype != torch.float32:
+            model.to(dtype=torch.float32)
+        break
+
     inc_model = quantization.fit(
         model, conf, calib_func=calib_func, calib_dataloader=calib_dataloader
     )
@@ -538,6 +545,8 @@ def default_calib_func(model):
         inc_model.model, None, None, config, device=device
     )

+    if orig_dtype != torch.float32:
+        q_model.to(dtype=orig_dtype)
     return q_model.to(device)

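The addition above wraps quantization.fit in a dtype round-trip: remember the dtype of the incoming parameters, upcast the model to float32 for the fit, then cast the quantized result back. A toy, self-contained illustration of that pattern, with nn.Linear standing in for the real model and the fit/replace_linear steps elided:

import torch

model = torch.nn.Linear(8, 8).to(dtype=torch.float16)  # pretend the caller passed fp16 weights

orig_dtype = torch.float32
for param in model.parameters():
    orig_dtype = param.dtype              # dtype the caller handed us
    if orig_dtype != torch.float32:
        model.to(dtype=torch.float32)     # the quantization pass expects fp32 weights
    break                                 # the first parameter is treated as representative

q_model = model                           # placeholder for quantization.fit(...) + replace_linear(...)

if orig_dtype != torch.float32:
    q_model.to(dtype=orig_dtype)          # hand back the original precision

assert next(q_model.parameters()).dtype == torch.float16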

intel_extension_for_transformers/transformers/modeling/modeling_auto.py

Lines changed: 13 additions & 23 deletions
@@ -67,12 +67,21 @@
     convert_to_quantized_model,
     replace_linear,
 )
+from ...tools.utils import get_gpu_family, is_ipex_available
 from neural_compressor.adaptor.torch_utils.model_wrapper import WeightOnlyLinear
 from transformers.configuration_utils import PretrainedConfig
 from transformers import AutoConfig
 from transformers.utils import is_accelerate_available, is_bitsandbytes_available
 from typing import Union

+if is_ipex_available() and get_gpu_family() != "no_gpu":
+    # pylint: disable=E0401
+    from intel_extension_for_pytorch.nn.utils._quantize_convert import (
+        WeightOnlyQuantizedLinear,
+    )
+else:
+    from ..llm.quantization.nn.modules import QuantizedLinearQBits
+
 torch = LazyImport("torch")

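The guard above resolves the weight-only linear class once at import time: IPEX's XPU kernel when IPEX is installed and an Intel GPU is visible, the QBits CPU class otherwise, so later isinstance checks need no local imports. A generic sketch of the same idea that runs anywhere, using try/except instead of the explicit availability helpers; torch.nn.Linear is only a stand-in for QuantizedLinearQBits:

import torch

try:
    # GPU path: the fused XPU kernel class named in the diff above.
    from intel_extension_for_pytorch.nn.utils._quantize_convert import (
        WeightOnlyQuantizedLinear as WOQLinear,
    )
except ImportError:
    # CPU fallback; torch.nn.Linear stands in for QuantizedLinearQBits in this sketch.
    WOQLinear = torch.nn.Linear

print("weight-only linear class:", WOQLinear.__name__)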

@@ -82,8 +91,6 @@ def recover_export_model(model, current_key_name=None):

     Return optimum format model.
     """
-    from ..llm.quantization.nn.modules import QuantizedLinearQBits
-
     for name, module in model.named_children():
         if current_key_name is None:
             current_key_name = []
@@ -165,19 +172,15 @@ def build_woq_model(model, quantization_config):

 def convert_model_to_public(model):
     # reorder weight and scales if they have been transposed
-    if model.quantization_config.device == "xpu":
-        # pylint: disable=E0401
-        from intel_extension_for_pytorch.nn.utils._quantize_convert import (
-            WeightOnlyQuantizedLinear,
-        )
-
+    if model.device == "xpu":
         for name, module in model.named_modules():
             if isinstance(module, WeightOnlyQuantizedLinear):
                 if module.weight_transposed:
                     module.qweight.data = module.qweight.t_().contiguous()
                     module.scales.data = module.scales.t_().contiguous()
                     module.weight_transposed = False
-    else:
+    elif model.quantization_config.weight_dtype not in \
+        ["fp8_e5m2", "fp8_e4m3", "nf4", "fp4", "int4_fullrange"]:
         model = recover_export_model(model)

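convert_model_to_public now branches on model.device, relies on the module-level import, and moves the weight_dtype filter onto the CPU export path; the XPU branch itself just untransposes packed weights and scales before export. A toy version of that reorder step, with SimpleNamespace standing in for a WeightOnlyQuantizedLinear module and only the fields used above:

import torch
from types import SimpleNamespace

module = SimpleNamespace(
    qweight=torch.randint(0, 16, (64, 8), dtype=torch.int32),  # packed weights; shape is illustrative
    scales=torch.rand(64, 2, dtype=torch.float16),
    weight_transposed=True,
)

# Same reorder as the XPU branch: transpose back and make contiguous before export.
if module.weight_transposed:
    module.qweight = module.qweight.t().contiguous()
    module.scales = module.scales.t().contiguous()
    module.weight_transposed = False

print(module.qweight.shape, module.scales.shape)  # torch.Size([8, 64]) torch.Size([2, 64])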

@@ -195,14 +198,7 @@ def save_low_bit(
         )
         return

-    if self.quantization_config.weight_dtype not in [
-        "fp8_e5m2",
-        "fp8_e4m3",
-        "nf4",
-        "fp4",
-        "int4_fullrange",
-    ]:
-        convert_model_to_public(self)
+    convert_model_to_public(self)
     os.makedirs(save_directory, exist_ok=True)
     # use transformers original `save_pretrained` function
     del self.save_pretrained
@@ -391,11 +387,6 @@ def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
                 "quantization_config: {}".format(config.quantization_config)
             )
             try:
-                kwargs["device_map"] = (
-                    config.quantization_config["device"]
-                    if "device" in config.quantization_config.keys()
-                    else "auto"
-                )
                 model = cls.load_low_bit(
                     pretrained_model_name_or_path,
                     *model_args,
@@ -598,7 +589,6 @@ def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
             model.config.update({"low_cpu_mem_usage": True})
             model.eval()

-            quantization_config.update(**{"device": "cpu"})
             if use_xpu:
                 import intel_extension_for_pytorch

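Net effect of these three hunks: save_low_bit always routes through convert_model_to_public (the weight_dtype filter now lives inside it), and device placement is no longer read from or written into the stored quantization_config. A hedged round-trip sketch under those assumptions; the ITREX wrapper import, the placeholder checkpoint, and the output directory are illustrative, not taken from this diff:

from intel_extension_for_transformers.transformers import AutoModelForCausalLM, RtnConfig

# Quantize on load (same call shape as the earlier sketch), then save and reload.
model = AutoModelForCausalLM.from_pretrained(
    "facebook/opt-125m",
    quantization_config=RtnConfig(compute_dtype="fp16", weight_dtype="int4_fullrange",
                                  group_size=32, scale_dtype="fp16"),
    device_map="xpu",
    trust_remote_code=True,
)
model.save_pretrained("./opt125m-int4")  # goes through save_low_bit -> convert_model_to_public

# Reload without re-quantizing; pass the device explicitly, since from_pretrained
# no longer derives device_map from a "device" key in the stored quantization_config.
reloaded = AutoModelForCausalLM.from_pretrained("./opt125m-int4", device_map="xpu",
                                                trust_remote_code=True)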

tests/CI/test_weight_only_gpu.py

Lines changed: 1 addition & 1 deletion
@@ -73,7 +73,7 @@ def forward(self, x):
         return self.linear(x)


-@unittest.skipIf(not _ipex_available or gpu_name == "no_gpu",
+@unittest.skipIf(not is_ipex_available() or gpu_name == "no_gpu",
                  "There is no Intel GPU in this machine, skip this test!")
 class TestArcWeightOnly(unittest.TestCase):

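The one-line test fix swaps a stale module-level flag for a live is_ipex_available() call evaluated at decoration time. A self-contained sketch of the same skip guard; the local probe stands in for intel_extension_for_transformers.tools.utils.is_ipex_available, and gpu_name is hard-coded here where the real test derives it from get_gpu_family():

import importlib.util
import unittest

def is_ipex_available() -> bool:
    # Stand-in for intel_extension_for_transformers.tools.utils.is_ipex_available.
    return importlib.util.find_spec("intel_extension_for_pytorch") is not None

gpu_name = "no_gpu"  # placeholder; the real test derives this from get_gpu_family()

@unittest.skipIf(not is_ipex_available() or gpu_name == "no_gpu",
                 "There is no Intel GPU in this machine, skip this test!")
class TestArcWeightOnlySketch(unittest.TestCase):
    def test_placeholder(self):
        self.assertTrue(True)

if __name__ == "__main__":
    unittest.main()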
