@@ -30,15 +30,15 @@ def init_tokenizer(cls, model):
3030
3131
3232 @classmethod
33- def load_model_from_config (model_class , model_config , load_in_8bit = False , load_in_4bit = False , weight_sharding = False ):
33+ def load_huggingface_model_from_config (model_class , model_config , load_in_8bit = False , load_in_4bit = False , weight_sharding = False ):
3434
3535 checkpoint = model_config ["huggingface_url" ]
3636
3737 if load_in_8bit and load_in_4bit :
3838 raise ValueError ("Only one of load_in_8bit or load_in_4bit can be True. Please choose one." )
3939
4040 # This "device" is for the case of CodeT5plus, will be removed in the future
41- device = torch .device ("cuda:0 " if torch .cuda .is_available () else "cpu" )
41+ device = torch .device ("cuda" if torch .cuda .is_available () else "cpu" )
4242 if weight_sharding :
4343 try :
4444 # Try to download and load the json index file
@@ -85,12 +85,10 @@ def load_model_from_config(model_class, model_config, load_in_8bit=False, load_i
8585 else :
8686 if model_config ["device_map" ]:
8787 model = AutoModelForSeq2SeqLM .from_pretrained (checkpoint ,
88- load_in_4bit = load_in_4bit ,
8988 low_cpu_mem_usage = True ,
9089 device_map = model_config ["device_map" ], trust_remote_code = model_config ["trust_remote_code" ])
9190 else :
9291 model = AutoModelForSeq2SeqLM .from_pretrained (checkpoint ,
93- load_in_4bit = load_in_4bit ,
9492 low_cpu_mem_usage = True ,
9593 trust_remote_code = model_config ["trust_remote_code" ]).to (device )
9694
@@ -103,6 +101,35 @@ def load_model_from_config(model_class, model_config, load_in_8bit=False, load_i
103101 tokenizer = tokenizer
104102 )
105103
@classmethod
def load_custom_model(model_class, checkpoint_path, tokenizer_path, load_in_8bit=False, load_in_4bit=False):
    """Load a custom/fine-tuned seq2seq checkpoint and its tokenizer.

    Args:
        checkpoint_path: Local path or hub id of the model checkpoint.
        tokenizer_path: Local path or hub id of the tokenizer.
        load_in_8bit: If True, quantize weights to 8-bit on load.
        load_in_4bit: If True, quantize weights to 4-bit on load.
            Mutually exclusive with ``load_in_8bit``.

    Returns:
        An instance of ``model_class`` wrapping the loaded model, a minimal
        model_config, and the tokenizer.

    Raises:
        ValueError: If both ``load_in_8bit`` and ``load_in_4bit`` are True.
    """
    if load_in_8bit and load_in_4bit:
        raise ValueError("Only one of load_in_8bit or load_in_4bit can be True. Please choose one.")

    # The three loading branches differed only in the quantization kwarg;
    # build the kwargs once instead of triplicating the from_pretrained call.
    quant_kwargs = {}
    if load_in_8bit:
        quant_kwargs["load_in_8bit"] = True
    elif load_in_4bit:
        quant_kwargs["load_in_4bit"] = True

    model = AutoModelForSeq2SeqLM.from_pretrained(
        checkpoint_path,
        low_cpu_mem_usage=True,
        device_map="auto",
        **quant_kwargs,
    )

    tokenizer = model_class.init_tokenizer(tokenizer_path)

    # BUG FIX: the original passed `model_config=model_config`, but no name
    # `model_config` exists in this scope (it is neither a parameter nor a
    # local), so every call raised NameError. Record the one piece of config
    # we actually know — the checkpoint location — using the same key the
    # hub-loading path reads ("huggingface_url").
    # NOTE(review): other model_config keys seen elsewhere in this file
    # ("device_map", "trust_remote_code") are unknown here — confirm whether
    # downstream consumers of model_config require them.
    model_config = {"huggingface_url": checkpoint_path}

    return model_class(
        model=model,
        model_config=model_config,
        tokenizer=tokenizer,
    )
132+
106133
107134 def forward (self , sources , max_length = 512 , beam_size = 5 ):
108135 encoding = self .tokenizer (sources , return_tensors = 'pt' ).to (self .model .device )
0 commit comments