This repository was archived by the owner on May 1, 2025. It is now read-only.

Commit d5db4bb

Author: Nghi Bui
Commit message: update model
Parent: 22ae1da

File tree

4 files changed: +79 additions, -6 deletions

codetf/models/__init__.py

Lines changed: 7 additions & 0 deletions

@@ -51,6 +51,13 @@ def load_model_pipeline(model_name, model_type="base", task="sum",
 
     return model
 
+def load_model_from_path(checkpoint_path, tokenizer_path, model_name, is_eval=True, load_in_8bit=False, load_in_4bit=False):
+    model_cls = registry.get_model_class(model_name)
+    model = model_cls.from_custom(checkpoint_path=checkpoint_path, tokenizer_path=tokenizer_path, load_in_8bit=load_in_8bit, load_in_4bit=load_in_4bit)
+    if is_eval:
+        model.eval()
+
+    return model
 
 class ModelZoo:
     def __init__(self, config_files):
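
For orientation, a minimal usage sketch of the new entry point; the checkpoint directory and the "causallm" registry key are hypothetical placeholders, not values from this commit:

    # Hypothetical usage of the new load_model_from_path helper.
    # The paths and the "causallm" registry key below are assumptions.
    from codetf.models import load_model_from_path

    model = load_model_from_path(
        checkpoint_path="./checkpoints/my-finetuned-model",  # local path or hub id (placeholder)
        tokenizer_path="./checkpoints/my-finetuned-model",   # often the same directory (placeholder)
        model_name="causallm",                               # must be a key known to codetf's model registry
        is_eval=True,                                        # switch the model to inference mode
    )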

codetf/models/base_model.py

Lines changed: 11 additions & 1 deletion

@@ -45,7 +45,17 @@ def from_pretrained(model_class, model_card, load_in_8bit=False, load_in_4bit=Fa
         Build a pretrained model from default configuration file, specified by model_type.
         """
         model_config = OmegaConf.load(get_abs_path(model_class.MODEL_DICT))[model_card]
-        model_cls = model_class.load_model_from_config(model_config=model_config, load_in_8bit=load_in_8bit, load_in_4bit=load_in_4bit, weight_sharding=weight_sharding)
+        model_cls = model_class.load_huggingface_model_from_config(model_config=model_config, load_in_8bit=load_in_8bit, load_in_4bit=load_in_4bit, weight_sharding=weight_sharding)
+
+        return model_cls
+
+    @classmethod
+    def from_custom(model_class, checkpoint_path, tokenizer_path, load_in_8bit=False, load_in_4bit=False):
+        """
+        Build a model from a custom checkpoint and tokenizer path, bypassing the default configuration file.
+        """
+        model_cls = model_class.load_custom_model(checkpoint_path, tokenizer_path, load_in_8bit=load_in_8bit, load_in_4bit=load_in_4bit)
 
         return model_cls
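
from_custom mirrors from_pretrained: as a classmethod it receives the concrete subclass as model_class and delegates to that subclass's load_custom_model, so each model family controls how a raw checkpoint is loaded. A sketch of the dispatch, assuming the registry import path (which this commit does not show) and a "causallm" key:

    # Sketch of the classmethod dispatch behind load_model_from_path.
    # The registry import path and the "causallm" key are assumptions.
    from codetf.common.registry import registry

    model_cls = registry.get_model_class("causallm")  # resolves to a BaseModel subclass
    model = model_cls.from_custom(
        checkpoint_path="./checkpoints/my-model",     # placeholder path
        tokenizer_path="./checkpoints/my-model",      # placeholder path
    )
    # from_custom forwards to model_cls.load_custom_model, defined per model family.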

codetf/models/causal_lm_models/__init__.py

Lines changed: 30 additions & 1 deletion

@@ -29,7 +29,7 @@ def init_tokenizer(cls, model):
         return tokenizer
 
     @classmethod
-    def load_model_from_config(model_class, model_config, load_in_8bit=False, load_in_4bit=False, weight_sharding=False):
+    def load_huggingface_model_from_config(model_class, model_config, load_in_8bit=False, load_in_4bit=False, weight_sharding=False):
         checkpoint = model_config["huggingface_url"]
 
         if load_in_8bit and load_in_4bit:

@@ -79,6 +79,35 @@ def load_model_from_config(model_class, model_config, load_in_8bit=False, load_i
             model_config=model_config,
             tokenizer=tokenizer
         )
+
+    @classmethod
+    def load_custom_model(model_class, checkpoint_path, tokenizer_path, load_in_8bit=False, load_in_4bit=False):
+
+        if load_in_8bit and load_in_4bit:
+            raise ValueError("Only one of load_in_8bit or load_in_4bit can be True. Please choose one.")
+
+        if load_in_8bit:
+            model = AutoModelForCausalLM.from_pretrained(checkpoint_path,
+                                                         load_in_8bit=load_in_8bit,
+                                                         low_cpu_mem_usage=True,
+                                                         device_map="auto")
+        elif load_in_4bit:
+            model = AutoModelForCausalLM.from_pretrained(checkpoint_path,
+                                                         load_in_4bit=load_in_4bit,
+                                                         low_cpu_mem_usage=True,
+                                                         device_map="auto")
+        else:
+            model = AutoModelForCausalLM.from_pretrained(checkpoint_path,
+                                                         low_cpu_mem_usage=True,
+                                                         device_map="auto")
+
+        tokenizer = model_class.init_tokenizer(tokenizer_path)
+
+        return model_class(
+            model=model,
+            model_config=None,  # no config file is used when loading from a custom path
+            tokenizer=tokenizer
+        )
 
     def forward(self, sources, max_length=512):
         encoding = self.tokenizer(sources, return_tensors='pt').to(self.device)
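
The three branches differ only in the quantization flag: load_in_8bit and load_in_4bit route weight loading through bitsandbytes, while device_map="auto" lets accelerate place layers across the available devices. As a rough sketch (placeholder path, and assuming bitsandbytes plus a CUDA device are available), the 8-bit branch amounts to:

    # Roughly what the load_in_8bit branch does, written against transformers directly.
    # The checkpoint path is a placeholder assumption.
    from transformers import AutoModelForCausalLM, AutoTokenizer

    model = AutoModelForCausalLM.from_pretrained(
        "./checkpoints/my-model",   # placeholder path
        load_in_8bit=True,          # bitsandbytes int8 weights
        low_cpu_mem_usage=True,     # avoid materializing a second full copy of the weights in RAM
        device_map="auto",          # let accelerate shard layers across available devices
    )
    tokenizer = AutoTokenizer.from_pretrained("./checkpoints/my-model")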

codetf/models/seq2seq_models/__init__.py

Lines changed: 31 additions & 4 deletions

@@ -30,15 +30,15 @@ def init_tokenizer(cls, model):
 
 
     @classmethod
-    def load_model_from_config(model_class, model_config, load_in_8bit=False, load_in_4bit=False, weight_sharding=False):
+    def load_huggingface_model_from_config(model_class, model_config, load_in_8bit=False, load_in_4bit=False, weight_sharding=False):
 
         checkpoint = model_config["huggingface_url"]
 
         if load_in_8bit and load_in_4bit:
             raise ValueError("Only one of load_in_8bit or load_in_4bit can be True. Please choose one.")
 
         # This "device" is for the case of CodeT5plus, will be removed in the future
-        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
         if weight_sharding:
             try:
                 # Try to download and load the json index file

@@ -85,12 +85,10 @@ def load_model_from_config(model_class, model_config, load_in_8bit=False, load_i
         else:
             if model_config["device_map"]:
                 model = AutoModelForSeq2SeqLM.from_pretrained(checkpoint,
-                                        load_in_4bit=load_in_4bit,
                                         low_cpu_mem_usage=True,
                                         device_map=model_config["device_map"], trust_remote_code=model_config["trust_remote_code"])
             else:
                 model = AutoModelForSeq2SeqLM.from_pretrained(checkpoint,
-                                        load_in_4bit=load_in_4bit,
                                         low_cpu_mem_usage=True,
                                         trust_remote_code=model_config["trust_remote_code"]).to(device)

@@ -103,6 +101,35 @@ def load_model_from_config(model_class, model_config, load_in_8bit=False, load_i
             tokenizer=tokenizer
         )
 
+    @classmethod
+    def load_custom_model(model_class, checkpoint_path, tokenizer_path, load_in_8bit=False, load_in_4bit=False):
+
+        if load_in_8bit and load_in_4bit:
+            raise ValueError("Only one of load_in_8bit or load_in_4bit can be True. Please choose one.")
+
+        if load_in_8bit:
+            model = AutoModelForSeq2SeqLM.from_pretrained(checkpoint_path,
+                                                          load_in_8bit=load_in_8bit,
+                                                          low_cpu_mem_usage=True,
+                                                          device_map="auto")
+        elif load_in_4bit:
+            model = AutoModelForSeq2SeqLM.from_pretrained(checkpoint_path,
+                                                          load_in_4bit=load_in_4bit,
+                                                          low_cpu_mem_usage=True,
+                                                          device_map="auto")
+        else:
+            model = AutoModelForSeq2SeqLM.from_pretrained(checkpoint_path,
+                                                          low_cpu_mem_usage=True,
+                                                          device_map="auto")
+
+        tokenizer = model_class.init_tokenizer(tokenizer_path)
+
+        return model_class(
+            model=model,
+            model_config=None,  # no config file is used when loading from a custom path
+            tokenizer=tokenizer
+        )
+
 
     def forward(self, sources, max_length=512, beam_size=5):
         encoding = self.tokenizer(sources, return_tensors='pt').to(self.model.device)
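
Putting the seq2seq pieces together, a minimal end-to-end sketch; the checkpoint directory and the "codet5" registry key are assumptions for illustration:

    # End-to-end sketch: load a fine-tuned seq2seq checkpoint and run inference.
    # The paths and the "codet5" registry key are placeholder assumptions.
    from codetf.models import load_model_from_path

    model = load_model_from_path(
        checkpoint_path="./checkpoints/my-codet5",
        tokenizer_path="./checkpoints/my-codet5",
        model_name="codet5",
    )
    # forward() matches the signature shown in the diff above.
    outputs = model.forward(["def add(a, b):\n    return a + b"], max_length=64, beam_size=5)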
