 import sys
 from pathlib import Path
 sys.path.append(str(Path(".").absolute().parent))
-from transformers import RobertaTokenizer
+from transformers import AutoTokenizer
 from codetf.models.base_model import BaseModel
 from transformers import AutoModelForSeq2SeqLM, AutoConfig
 from codetf.common.registry import registry
 from accelerate import Accelerator
+import torch
+from accelerate import init_empty_weights, load_checkpoint_and_dispatch
+from huggingface_hub import hf_hub_download
 
 @registry.register_model("codet5")
 class Seq2SeqModel(BaseModel):
@@ -22,7 +26,7 @@ def __init__(self, model, model_config, tokenizer):
 
     @classmethod
     def init_tokenizer(cls, model):
-        return RobertaTokenizer.from_pretrained(model)
+        return AutoTokenizer.from_pretrained(model)
 
 
     @classmethod
@@ -33,28 +37,62 @@ def load_model_from_config(model_class, model_config, load_in_8bit=False, load_i
         if load_in_8bit and load_in_4bit:
             raise ValueError("Only one of load_in_8bit or load_in_4bit can be True. Please choose one.")
 
+        # This "device" is only needed for the CodeT5+ case; it will be removed in the future
+        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
         if weight_sharding:
-            weights_location = hf_hub_download(checkpoint, "pytorch_model.bin")
+            try:
+                # Try the single-file pytorch_model.bin checkpoint first
+                weights_location = hf_hub_download(checkpoint, "pytorch_model.bin")
+            except Exception:
+                try:
+                    # If that fails, fall back to the index file of a sharded checkpoint
+                    weights_location = hf_hub_download(checkpoint, "pytorch_model.bin.index.json")
+                except Exception as e:
+                    # If both fail, raise an error
+                    raise Exception(f"Failed to download weights: {str(e)}")
             config = AutoConfig.from_pretrained(checkpoint)
             with init_empty_weights():
                 model = AutoModelForSeq2SeqLM.from_config(config)
 
             model.tie_weights()
             model = load_checkpoint_and_dispatch(
-                model, weights_location, device_map="auto", no_split_module_classes=["GPTJBlock"]
+                model, weights_location, device_map=model_config["device_map"],
+                no_split_module_classes=["GPTJBlock"]
             )
         else:
             if load_in_8bit:
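+                # low_cpu_mem_usage streams weight shards into place instead of first
+                # materialising a full copy of the model in CPU RAM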
-                model = AutoModelForSeq2SeqLM.from_pretrained(checkpoint,
-                                                              load_in_8bit=load_in_8bit,
-                                                              device_map="auto")
+                if model_config["device_map"]:
+                    model = AutoModelForSeq2SeqLM.from_pretrained(checkpoint,
+                                                                  load_in_8bit=load_in_8bit,
+                                                                  low_cpu_mem_usage=True,
+                                                                  device_map="auto",
+                                                                  trust_remote_code=model_config["trust_remote_code"])
+                else:
+                    model = AutoModelForSeq2SeqLM.from_pretrained(checkpoint,
+                                                                  load_in_8bit=load_in_8bit,
+                                                                  low_cpu_mem_usage=True,
+                                                                  trust_remote_code=model_config["trust_remote_code"])
             elif load_in_4bit:
-                model = AutoModelForSeq2SeqLM.from_pretrained(checkpoint,
-                                                              load_in_4bit=load_in_4bit,
-                                                              device_map="auto")
+                if model_config["device_map"]:
+                    model = AutoModelForSeq2SeqLM.from_pretrained(checkpoint,
+                                                                  load_in_4bit=load_in_4bit,
+                                                                  low_cpu_mem_usage=True,
+                                                                  device_map="auto",
+                                                                  trust_remote_code=model_config["trust_remote_code"])
+                else:
+                    model = AutoModelForSeq2SeqLM.from_pretrained(checkpoint,
+                                                                  load_in_4bit=load_in_4bit,
+                                                                  low_cpu_mem_usage=True,
+                                                                  trust_remote_code=model_config["trust_remote_code"])
             else:
-                model = AutoModelForSeq2SeqLM.from_pretrained(checkpoint,
-                                                              device_map="auto")
+                if model_config["device_map"]:
+                    model = AutoModelForSeq2SeqLM.from_pretrained(checkpoint,
+                                                                  low_cpu_mem_usage=True,
+                                                                  device_map=model_config["device_map"],
+                                                                  trust_remote_code=model_config["trust_remote_code"])
+                else:
+                    model = AutoModelForSeq2SeqLM.from_pretrained(checkpoint,
+                                                                  low_cpu_mem_usage=True,
+                                                                  trust_remote_code=model_config["trust_remote_code"]).to(device)
 
 
         tokenizer = model_class.init_tokenizer(model_config.get("tokenizer_url"))
@@ -66,22 +104,20 @@ def load_model_from_config(model_class, model_config, load_in_8bit=False, load_i
         )
 
 
-    def forward(self, sources):
-        encoding = self.tokenizer(sources, return_tensors='pt')
-        input_ids = encoding.input_ids.to(self.device)
-        attention_mask = encoding.attention_mask.to(self.device)
-        generated_ids = self.model.generate(input_ids, attention_mask=attention_mask,
-                                            max_length=self.max_prediction_length,
-                                            num_beams=self.beam_size)
+    def forward(self, sources, max_length=512, beam_size=5):
+        encoding = self.tokenizer(sources, return_tensors='pt').to(self.model.device)
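+        # Use the input ids as the decoder prompt, following the CodeT5+ generation recipe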
+        encoding['decoder_input_ids'] = encoding['input_ids'].clone()
+        generated_ids = self.model.generate(**encoding,
+                                            max_length=max_length,
+                                            num_beams=beam_size)
 
         predictions = self.tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
         return predictions
 
 
-    def predict(self, sources):
+    def predict(self, sources, max_length=512, beam_size=5):
 
         input_for_net = [' '.join(source.strip().split()).replace('\n', ' ') for source in sources]
-        # if self.task in ["sum", "translate", "nl2code", "refine"]:
-        output = self.forward(input_for_net)
+        output = self.forward(input_for_net, max_length, beam_size)
 
         return output
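
For reference, a minimal usage sketch of the revised class. This is not part of the commit: the "huggingface_url" config key, the checkpoint id, and the import path are assumptions for illustration (only "tokenizer_url", "device_map", and "trust_remote_code" are visible in this diff), and it assumes quantization and weight sharding default to off.

# Hypothetical usage sketch -- not part of this commit.
from codetf.models.seq2seq_models import Seq2SeqModel  # assumed import path

model_config = {
    "huggingface_url": "Salesforce/codet5-base",  # assumed key name and checkpoint
    "tokenizer_url": "Salesforce/codet5-base",
    "device_map": None,   # falls through to the plain .to(device) branch
    "trust_remote_code": False,
}

model = Seq2SeqModel.load_model_from_config(model_config)
print(model.predict(["def add(a, b): return a + b"], max_length=64, beam_size=3))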