 }
 
 
-def replace_linear(model, modules_to_not_convert=None, current_key_name=None, quantization_config=None):
+def replace_linear(
+    model,
+    modules_to_not_convert=None,
+    current_key_name=None,
+    quantization_config=None,
+    device="cpu",
+    empty_weights=False
+):
     if modules_to_not_convert is None:
         modules_to_not_convert = ["lm_head"]
     if quantization_config.llm_int8_skip_modules:
         modules_to_not_convert.extend(quantization_config.llm_int8_skip_modules)
     model, is_replaced = _replace_linear(
-        model, modules_to_not_convert, current_key_name, quantization_config
+        model, modules_to_not_convert, current_key_name, quantization_config, device=device,
+        empty_weights=empty_weights
     )
 
     if not is_replaced:
@@ -71,7 +79,13 @@ def convert_dtype_2_str(dtype):
 
 
 def _replace_linear(
-    model, modules_to_not_convert=None, current_key_name=None, quantization_config=None, is_replaced=False
+    model,
+    modules_to_not_convert=None,
+    current_key_name=None,
+    quantization_config=None,
+    is_replaced=False,
+    device="cpu",
+    empty_weights=False
 ):
     """
     Private method that wraps the recursion for module replacement.
@@ -85,12 +99,25 @@ def _replace_linear(
 
         if isinstance(module, torch.nn.Linear) and name not in modules_to_not_convert:
             # Check if the current key is not in the `modules_to_not_convert`
-            from .nn import QuantizedLinearQBits  # TODO: QuantizedLinearINT4, QuantizedLinearINT8
             if not any(key in ".".join(current_key_name) for key in modules_to_not_convert):
                 with init_empty_weights():
                     in_features = module.in_features
                     out_features = module.out_features
-
+                    if device == "cpu" or device == torch.device("cpu"):
+                        from .nn.modules import QuantizedLinearQBits  # TODO: QuantizedLinearINT4, QuantizedLinearINT8
+                        model._modules[name] = QuantizedLinearQBits(
+                            in_features,
+                            out_features,
+                            module.bias is not None,
+                            compute_dtype=quantization_config.compute_dtype,
+                            compress_statistics=False,
+                            weight_dtype=quantization_config.weight_dtype,
+                            scale_dtype=quantization_config.scale_dtype,
+                            blocksize=quantization_config.group_size,
+                            scheme=quantization_config.scheme
+                        )
+                    else:
+                        raise Exception("Unsupported device for weight-only quantization: {}".format(device))
                     # if quantization_config.quantization_method() == "s8":
                     #     model._modules[name] = QuantizedLinearINT8(
                     #         in_features,
@@ -113,42 +140,36 @@ def _replace_linear(
                     #         scheme=quantization_config.scheme
                     #     )
                     # is_replaced = True
-                    model._modules[name] = QuantizedLinearQBits(
-                        in_features,
-                        out_features,
-                        module.bias is not None,
-                        compute_dtype=quantization_config.compute_dtype,
-                        compress_statistics=False,
-                        weight_dtype=quantization_config.weight_dtype,
-                        scale_dtype=quantization_config.scale_dtype,
-                        blocksize=quantization_config.group_size,
-                        scheme=quantization_config.scheme
-                    )
                     is_replaced = True
                     # Store the module class in case we need to transpose the weight later
                     model._modules[name].source_cls = type(module)
                     # Force requires grad to False to avoid unexpected errors
                     model._modules[name].requires_grad_(False)
-                model._modules[name].set_weights_bias(
-                    module.weight.data, None if module.bias is None else module.bias.data
-                )
+                if not empty_weights:
+                    model._modules[name].set_weights_bias(
+                        module.weight.data, None if module.bias is None else module.bias.data
+                    )
+
         if len(list(module.children())) > 0:
             _, is_replaced = _replace_linear(
                 module,
                 modules_to_not_convert,
                 current_key_name,
                 quantization_config,
                 is_replaced=is_replaced,
+                device=device,
+                empty_weights=empty_weights,
             )
         # Remove the last key for recursion
         current_key_name.pop(-1)
     return model, is_replaced
 
 
-def convert_to_quantized_model(model, config):
+def convert_to_quantized_model(model, config, device="cpu"):
     calib_dataloader = config.calib_dataloader
     calib_func = config.calib_func
     calib_iters = config.calib_iters
+    model_device = next(model.parameters()).device
     if calib_dataloader is None and config.algorithm in ['TEQ', 'AWQ']:
         from datasets import load_dataset
         from torch.utils.data import DataLoader
@@ -164,7 +185,7 @@ def convert_to_quantized_model(model, config):
164185 + " from transformer import AutoTokenizer \n "
165186 + " tokenizer = AutoTokenizer.from_pretrained(model_name_or_path) \n "
166187 )
167- exit (0 )
188+ exit (0 )
168189
         def tokenize_function(examples):
             if "prompt" in examples:
@@ -218,7 +239,7 @@ def default_calib_func(model):
                 + "batchsize is 1 and calibration iteration is 100."
             )
     if config.weight_dtype in ["fp8_e4m3", "fp8_e5m2"]:
-        return replace_linear(model, None, None, config)
+        return replace_linear(model, None, None, config, device=device)
     else:
         bits = DTYPE_BITS_MAPPING[config.weight_dtype]
         if config.weight_dtype == "int8":
@@ -253,5 +274,5 @@ def default_calib_func(model):
             conf,
             calib_func=calib_func,
             calib_dataloader=calib_dataloader)
-    return replace_linear(inc_model.model, None, None, config)
+    return replace_linear(inc_model.model, None, None, config, device=device)
 
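
Notes on usage. Below is a minimal sketch of the two new arguments, assuming the QBits backend is installed; the SimpleNamespace stands in for the project's real quantization config object, and every field value is illustrative rather than taken from this commit.

from types import SimpleNamespace
import torch

# Stand-in for the real quantization config; these are exactly the fields
# replace_linear/_replace_linear read above, with illustrative values.
config = SimpleNamespace(
    llm_int8_skip_modules=None,
    compute_dtype="fp32",
    weight_dtype="int4_clip",
    scale_dtype="fp32",
    group_size=32,
    scheme="sym",
)

# Default path: each nn.Linear becomes a QuantizedLinearQBits and
# set_weights_bias quantizes the original fp32 weights immediately.
dense = torch.nn.Sequential(torch.nn.Linear(64, 64))
dense = replace_linear(dense, None, None, config, device="cpu")

# empty_weights=True: modules are swapped but set_weights_bias is skipped,
# leaving empty shells to be filled later, e.g. from a saved quantized
# checkpoint via load_state_dict.
shell = torch.nn.Sequential(torch.nn.Linear(64, 64))
shell = replace_linear(shell, None, None, config, device="cpu", empty_weights=True)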
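
The device guard itself can be exercised without the QBits kernels, since the QuantizedLinearQBits import only happens on the CPU branch; reusing the illustrative config above:

# Any non-CPU device now raises before a QuantizedLinearQBits is built,
# matching the explicit else branch added in _replace_linear.
try:
    replace_linear(torch.nn.Sequential(torch.nn.Linear(8, 8)),
                   None, None, config, device="cuda")
except Exception as err:
    print(err)  # unsupported-device message raised by _replace_linear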