@@ -216,6 +216,14 @@ def encode_tokens(tokenizer, string, bos=True, device=default_device):
         tokens = [tokenizer.bos_id()] + tokens
     return torch.tensor(tokens, dtype=torch.int, device=device)
 
+def _convert_weight(model):
+    from quantize import WeightOnlyInt4Linear
+    for fqn, mod in model.named_modules():
+        if isinstance(mod, WeightOnlyInt4Linear):
+            weight = mod.weight.data
+            weight_int4pack = torch.ops.aten._convert_weight_to_int4pack(weight, mod.inner_k_tiles)
+            mod.weight = weight_int4pack
+
 def _load_model(checkpoint_path, device, precision, use_tp):
     use_cuda = 'cuda' in device
     with torch.device('meta'):
@@ -240,19 +248,15 @@ def _load_model(checkpoint_path, device, precision, use_tp):
         checkpoint = checkpoint["model"]
     model.load_state_dict(checkpoint, assign=True)
 
+    model = model.to(device=device, dtype=precision)
+    # the int4 packed weight must be converted after the model has been loaded onto the target device
+    if "int4" in str(checkpoint_path):
+        _convert_weight(model)
+
     if use_tp:
         from tp import apply_tp
         print("Applying tensor parallel to model ...")
         apply_tp(model)
-
-    model = model.to(device=device, dtype=precision)
-    if "int4" in str(checkpoint_path):
-        from quantize import WeightOnlyInt4Linear
-        for fqn, mod in model.named_modules():
-            if isinstance(mod, WeightOnlyInt4Linear):
-                weight = mod.weight.data
-                weight_int4pack = torch.ops.aten._convert_weight_to_int4pack(weight, mod.inner_k_tiles)
-                mod.weight = weight_int4pack
     return model.eval()
 
 def _get_model_size(model):
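
For context, here is a minimal usage sketch of the reordered load path. The module name and the checkpoint filename are assumptions (the diff does not name its file; gpt-fast keeps _load_model in generate.py), and _load_model's signature is taken from the hunk above:

    import torch
    from generate import _load_model  # assumed module; _load_model is defined in this diff

    # Per the diff, _load_model now moves the model to the target device first and
    # only then calls _convert_weight(), since aten._convert_weight_to_int4pack
    # produces a device-specific packed layout. Having "int4" in the checkpoint
    # name is what triggers the conversion.
    model = _load_model(
        checkpoint_path="checkpoints/model_int4.pth",  # hypothetical path
        device="cuda",
        precision=torch.bfloat16,
        use_tp=False,
    )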