@@ -37,6 +37,7 @@ def _is_auto_round_available():
 from auto_round.export.export_to_itrex.export import pack_model  # pylint: disable=E0401
 from auto_round.mllm import lmms_eval, mllm_eval
 from auto_round.mllm.template import Template, get_template
+from auto_round.schemes import QuantizationScheme
 
 from neural_compressor.torch.algorithms import Quantizer
 from neural_compressor.torch.utils import get_accelerator, logger
@@ -53,7 +54,7 @@ def __init__(
         enable_full_range: bool = False,  ##for symmetric, TODO support later
         batch_size: int = 8,
         amp: bool = True,
-        device: str = None,
+        device_map: str = None,
         lr_scheduler=None,
         dataset: Union[str, list, tuple, torch.utils.data.DataLoader] = "NeelNanda/pile-10k",
         enable_quanted_input: bool = True,
@@ -91,6 +92,8 @@ def __init__(
         processor=None,
         template: Union[str, Template] = None,
         truncation: bool = False,
+        # 0.7
+        scheme: Union[str, dict, QuantizationScheme] = "W4A16",
         **kwargs,
     ):
         """Init a AutQRoundQuantizer object.
@@ -122,7 +125,7 @@ def __init__(
             enable_full_range (bool): Whether to enable full range quantization (default is False).
             batch_size (int): Batch size for training (default is 8).
             amp (bool): Whether to use automatic mixed precision (default is True).
-            device: The device to be used for tuning (default is "auto").
+            device_map: The device to be used for tuning (default is None).
             lr_scheduler: The learning rate scheduler to be used.
             dataset (str): The default dataset name (default is "NeelNanda/pile-10k").
             enable_quanted_input (bool): Whether to use the output of the previous quantized block as
@@ -161,6 +164,7 @@ def __init__(
             image_processor (Processor): Image processor for special model like llava.
             template (Template): The template to specify process for different mllms.
             truncation (bool): Activates truncation to cut input sequences longer than `max_length` to `max_length`.
+            scheme (str | dict | QuantizationScheme): A preset scheme that defines the quantization configurations.
 
         Returns:
             The quantized model.
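Note on the new argument documented above: `scheme` accepts a preset name string, a plain dict, or a `QuantizationScheme` instance from `auto_round.schemes`. A minimal sketch of the first two forms follows; the dict keys shown (bits, group_size, sym) are assumptions about auto_round's scheme fields, not something taken from this diff.

```python
# Illustrative only: two ways the new `scheme` argument might be supplied.
# Key names below are assumptions about auto_round's scheme definition.
scheme_as_name = "W4A16"                                      # preset string (the new default)
scheme_as_dict = {"bits": 4, "group_size": 128, "sym": True}  # assumed per-scheme keys
```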
@@ -205,6 +209,8 @@ def __init__(
         self.image_processor = image_processor
         self.template = template
         self.truncation = truncation
+        self.scheme = scheme
+        self.device_map = device_map
         self.enable_w4afp8 = self._is_w4afp8()
 
     def _is_w4afp8(self):
@@ -237,12 +243,13 @@ def convert(self, model: torch.nn.Module, *args, **kwargs):
             rounder = AutoRoundMLLM(
                 model,
                 tokenizer=self.tokenizer,
+                scheme=self.scheme,
                 processor=self.processor,
                 image_processor=self.image_processor,
                 layer_config=self.quant_config,
                 batch_size=self.batch_size,
                 amp=self.amp,
-                device=self.device,
+                device_map=self.device_map,
                 lr_scheduler=self.lr_scheduler,
                 dataset=dataloader,
                 extra_data_dir=self.extra_data_dir,
@@ -278,12 +285,13 @@ def convert(self, model: torch.nn.Module, *args, **kwargs):
             rounder = AutoRound(
                 model=model,
                 tokenizer=self.tokenizer,
+                scheme=self.scheme,
                 dataset=dataloader,
                 layer_config=self.quant_config or {},
                 enable_full_range=self.enable_full_range,
                 batch_size=self.batch_size,
                 amp=self.amp,
-                device=self.device,
+                device_map=self.device_map,
                 lr_scheduler=self.lr_scheduler,
                 enable_quanted_input=self.enable_quanted_input,
                 enable_minmax_tuning=self.enable_minmax_tuning,
@@ -317,7 +325,7 @@ def convert(self, model: torch.nn.Module, *args, **kwargs):
         elif "itrex" in self.export_format:
             model = pack_model(model, weight_config, device=self.device, inplace=True)
         else:  # pragma: no cover
-            model = rounder.save_quantized(output_dir=None, format=self.export_format, device=self.device, inplace=True)
+            model = rounder.save_quantized(output_dir="temp_auto_round", format=self.export_format, inplace=True)
 
         return model
 
@@ -341,9 +349,7 @@ def get_dataloader(tokenizer, seqlen, dataset_name="NeelNanda/pile-10k", seed=42
     """
     from auto_round.calib_dataset import get_dataloader  # pylint: disable=E0401
 
-    dataloader = get_dataloader(
-        tokenizer, seqlen, dataset_name="NeelNanda/pile-10k", seed=seed, bs=bs, nsamples=nsamples
-    )
+    dataloader = get_dataloader(tokenizer, seqlen, dataset_name=dataset_name, seed=seed, bs=bs, nsamples=nsamples)
     return dataloader
 
 
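The last hunk also fixes a bug: the wrapper previously ignored its `dataset_name` parameter and always calibrated on "NeelNanda/pile-10k". A short sketch of the fixed behavior; the tokenizer checkpoint and dataset id are placeholders for illustration.

```python
# Sketch of the fixed behavior: the caller's dataset_name is now forwarded
# instead of being silently replaced by the hard-coded pile-10k default.
from transformers import AutoTokenizer

from neural_compressor.torch.algorithms.weight_only.autoround import get_dataloader

tokenizer = AutoTokenizer.from_pretrained("facebook/opt-125m")  # placeholder tokenizer
dataloader = get_dataloader(
    tokenizer,
    seqlen=2048,
    dataset_name="some/calib-dataset",  # placeholder id; previously this argument was ignored
    seed=42,
    bs=8,
    nsamples=128,
)
```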