     WeightOnlyQuantConfig,
     BitsAndBytesConfig
 )
+
+import shutil
+
 if is_deepspeed_available():
     import deepspeed  # pylint: disable=E0401
 
@@ -240,11 +243,19 @@ def import_deepspeed():
     logging.info("DeepSpeed is enabled.")
 
 
-def init_deepspeed_inference(model, model_name_or_path, use_hpu_graphs, is_meta, token=None):
+def init_deepspeed_inference(model, model_name_or_path, peft_path, use_hpu_graphs, is_meta, token=None):
     # Initialize the model
     from habana_frameworks.torch.distributed.hccl import initialize_distributed_hpu  # pylint: disable=E0401
 
     world_size, rank, local_rank = initialize_distributed_hpu()
+    merged_model_dir = None
+    if peft_path and is_meta:
+        merged_model_dir = "/tmp/text_generation_merged_peft_model"
+        if local_rank == 0:
+            if Path(merged_model_dir).is_dir():
+                shutil.rmtree(merged_model_dir)
+            peft_model(model_name_or_path, peft_path, torch.bfloat16, token).save_pretrained(merged_model_dir)
+        torch.distributed.barrier()
 
     model = model.eval()
     ds_inference_kwargs = {"dtype": torch.bfloat16}
@@ -253,7 +264,8 @@ def init_deepspeed_inference(model, model_name_or_path, use_hpu_graphs, is_meta,
     # Make sure all devices/nodes have access to the model checkpoints
     if is_meta:
         checkpoints_json = "checkpoints.json"
-        write_checkpoints_json(model_name_or_path, local_rank, checkpoints_json, token)
+        write_checkpoints_json(merged_model_dir if merged_model_dir is not None else model_name_or_path, local_rank,
+                               checkpoints_json, token)
 
     torch.distributed.barrier()
 
@@ -264,6 +276,50 @@ def init_deepspeed_inference(model, model_name_or_path, use_hpu_graphs, is_meta,
     model = deepspeed.init_inference(model, **ds_inference_kwargs)
     return model.module
 
+
+def peft_model(model_name, peft_model, model_dtype, hf_access_token=None):
+    import importlib.util
+
+    if importlib.util.find_spec("peft") is None:
+        raise ImportError("The `peft` package is not installed, please run: `pip install peft`.")
+    from peft import AutoPeftModelForCausalLM
+    from peft.config import PeftConfigMixin
+
+    base_model_name = PeftConfigMixin.from_pretrained(
+        peft_model,
+        use_auth_token=hf_access_token,
+    ).base_model_name_or_path
+
+    base_model_is_local = Path(base_model_name).is_dir()
+    if not base_model_is_local:
+        # Check if the base model path to a remote repository on the HF Hub exists
+        from huggingface_hub import list_repo_files
+
+        try:
+            list_repo_files(base_model_name)
+            base_model_is_remote = True
+        except Exception:
+            base_model_is_remote = False
+
+    if base_model_is_local or base_model_is_remote:
+        model = AutoPeftModelForCausalLM.from_pretrained(peft_model, torch_dtype=model_dtype, low_cpu_mem_usage=True,
+                                                         use_auth_token=hf_access_token)
+    else:
+        # Since the base model doesn't exist locally nor remotely, use `args.model_name_or_path` as the base model
+        print(
+            f"The base model `{base_model_name}` of the LoRA configuration associated"
+            f" to `{peft_model}` does not exist locally or remotely. Using "
+            f"`--model_name_or_path {model_name}` as a fall back for the base model."
+        )
+        from peft import PeftModel
+
+        model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=model_dtype, low_cpu_mem_usage=True,
+                                                     use_auth_token=hf_access_token)
+        model = PeftModel.from_pretrained(model, peft_model, torch_dtype=model_dtype, low_cpu_mem_usage=True,
+                                          use_auth_token=hf_access_token)
+
+    return model.merge_and_unload()
+
 def load_model(
     model_name,
     tokenizer_name,
@@ -376,9 +432,6 @@ def load_model(
         logging.info("Optimized Model loaded.")
         return
 
-    if peft_path and device == "hpu" and use_deepspeed and load_to_meta:
-        logging.warning("PEFT could not work in deepspeed sharded checkpt loading mode, set load_to_meta to False")
-        load_to_meta = False
     if device == "hpu" and use_deepspeed and load_to_meta:
         with deepspeed.OnDevice(dtype=torch.bfloat16, device="meta"):
             model = AutoModelForCausalLM.from_config(config, torch_dtype=torch.bfloat16)
@@ -500,7 +553,7 @@ def load_model(
         model.generation_config.eos_token_id = tokenizer.eos_token_id
 
     if device == "hpu":
-        if peft_path:
+        if peft_path and not (use_deepspeed and load_to_meta):
             from peft import PeftModel
             model = PeftModel.from_pretrained(model, peft_path)
             model = model.to(torch.bfloat16)
@@ -516,6 +569,7 @@ def load_model(
             model = init_deepspeed_inference(
                 model=model,
                 model_name_or_path=model_name,
+                peft_path=peft_path,
                 use_hpu_graphs=use_hpu_graphs,
                 is_meta=load_to_meta,
                 token=hf_access_token,
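
Taken together, these hunks let a PEFT adapter be used with DeepSpeed's sharded, meta-device checkpoint loading on HPU: instead of forcing `load_to_meta = False` (the removed warning path), rank 0 merges the adapter into the base model with the new `peft_model()` helper, saves the merged weights to `/tmp/text_generation_merged_peft_model`, and every rank waits at a barrier before `write_checkpoints_json` points DeepSpeed at that directory. The sketch below illustrates the merge-on-rank-0 pattern outside the repository's code; the helper name is hypothetical, it assumes a `torch.distributed` process group is already initialized (as it is once DeepSpeed/HCCL is set up), and it calls `AutoPeftModelForCausalLM` directly rather than the patch's `peft_model()` fallback logic.

```python
# Minimal sketch (not the repository's API): merge a LoRA adapter on rank 0,
# then let every rank load the merged checkpoint. Assumes torch.distributed
# has already been initialized by the DeepSpeed launcher.
import shutil
from pathlib import Path

import torch
import torch.distributed as dist
from peft import AutoPeftModelForCausalLM


def prepare_checkpoint_dir(adapter_path, base_model_path, local_rank,
                           merged_dir="/tmp/text_generation_merged_peft_model"):
    """Return the directory every rank should hand to write_checkpoints_json."""
    if not adapter_path:
        return base_model_path  # no adapter: load the base checkpoint as-is

    if local_rank == 0:
        # Only one process writes, so concurrent ranks never race on /tmp.
        if Path(merged_dir).is_dir():
            shutil.rmtree(merged_dir)
        merged = AutoPeftModelForCausalLM.from_pretrained(
            adapter_path, torch_dtype=torch.bfloat16, low_cpu_mem_usage=True
        ).merge_and_unload()  # fold the LoRA weights back into the base model
        merged.save_pretrained(merged_dir)

    # All ranks (including rank 0) block here, so nobody reads a half-written
    # checkpoint before rank 0 has finished saving it.
    dist.barrier()
    return merged_dir
```

The payoff over the removed fallback is that the meta-device path stays intact: the merged directory is simply substituted for `model_name_or_path` when `checkpoints.json` is built, so `deepspeed.init_inference` can still shard the checkpoint from disk instead of materializing the full model on every rank first.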