This repository was archived by the owner on May 1, 2025. It is now read-only.

Commit d088072
Author: Nghi Bui (committed)
Commit message: "update"
Parent: 4cd3e6b

File tree: 7 files changed, +142 −61 lines

.gitignore
Lines changed: 1 addition & 0 deletions

@@ -3,6 +3,7 @@ __pycache__/
 *.py[cod]
 *$py.class
 
+codetf/__pycache__/
 codetf/.DS_Store
 assets/.DS_Store
codetf/configs/inference/codet5.yaml
Lines changed: 41 additions & 2 deletions

@@ -4,117 +4,156 @@ codet5-base-multi-sum-pretrained:
   max_source_length: 512
   max_prediction_length: 512
   beam_size: 5
+  trust_remote_code: False
+  device_map: "auto"
 codet5-base-nl2code:
   huggingface_url: "Salesforce/codet5-base-codexglue-concode"
   tokenizer_url: "Salesforce/codet5-base"
   max_source_length: 512
   max_prediction_length: 512
   beam_size: 5
+  trust_remote_code: False
+  device_map: True
 codet5-base-refine:
   huggingface_url: "Salesforce/codet5-base-codexglue-refine-medium"
   tokenizer_url: "Salesforce/codet5-base"
   max_source_length: 512
   max_prediction_length: 512
   beam_size: 5
+  trust_remote_code: False
+  device_map: True
 codet5-base-translate_cs_java:
   huggingface_url: "Salesforce/codet5-base-codexglue-translate-cs-java"
   tokenizer_url: "Salesforce/codet5-base"
   max_source_length: 512
   max_prediction_length: 512
   beam_size: 5
+  trust_remote_code: False
+  device_map: True
 codet5-base-translate_java_cs:
   huggingface_url: "Salesforce/codet5-base-codexglue-translate-java-cs"
   tokenizer_url: "Salesforce/codet5-base"
   max_source_length: 512
   max_prediction_length: 512
   beam_size: 5
+  trust_remote_code: False
+  device_map: True
 codet5-base-sum_python:
   huggingface_url: "Salesforce/codet5-base-codexglue-sum-python"
   tokenizer_url: "Salesforce/codet5-base"
   max_source_length: 512
   max_prediction_length: 512
   beam_size: 5
+  trust_remote_code: False
+  device_map: True
 codet5-base-sum_go:
   huggingface_url: "Salesforce/codet5-base-codexglue-sum-go"
   tokenizer_url: "Salesforce/codet5-base"
   max_source_length: 512
   max_prediction_length: 512
   beam_size: 5
+  trust_remote_code: False
+  device_map: True
 codet5-base-sum_php:
   huggingface_url: "Salesforce/codet5-base-codexglue-sum-php"
   tokenizer_url: "Salesforce/codet5-base"
   max_source_length: 512
   max_prediction_length: 512
   beam_size: 5
+  trust_remote_code: False
+  device_map: True
 codet5-base-sum_javascript:
   huggingface_url: "Salesforce/codet5-base-codexglue-sum-javascript"
   tokenizer_url: "Salesforce/codet5-base"
   max_source_length: 512
   max_prediction_length: 512
   beam_size: 5
+  trust_remote_code: False
+  device_map: True
 codet5-base-sum_java:
   huggingface_url: "Salesforce/codet5-base-codexglue-sum-java"
   tokenizer_url: "Salesforce/codet5-base"
   max_source_length: 512
   max_prediction_length: 512
   beam_size: 5
+  trust_remote_code: False
+  device_map: True
 codet5-base-sum_ruby:
   huggingface_url: "Salesforce/codet5-base-codexglue-sum-ruby"
   tokenizer_url: "Salesforce/codet5-base"
   max_source_length: 512
   max_prediction_length: 512
   beam_size: 5
+  trust_remote_code: False
+  device_map: True
 codet5-base-clone:
   huggingface_url: "Salesforce/codet5-base-codexglue-clone"
   tokenizer_url: "Salesforce/codet5-base"
   max_source_length: 512
   max_prediction_length: 512
   beam_size: 5
+  trust_remote_code: False
+  device_map: True
 codet5-base-defect:
   huggingface_url: "Salesforce/codet5-base-codexglue-defect"
   tokenizer_url: "Salesforce/codet5-base"
   max_source_length: 512
   max_prediction_length: 512
   beam_size: 5
+  trust_remote_code: False
+  device_map: True
 codet5-plus-instruct-16B-pretrained:
   huggingface_url: "Salesforce/instructcodet5p-16b"
   tokenizer_url: "Salesforce/instructcodet5p-16b"
   max_source_length: 512
   max_prediction_length: 512
   beam_size: 5
+  trust_remote_code: True
+  device_map: False
 codet5-plus-16B-pretrained:
   huggingface_url: "Salesforce/codet5p-16b"
   tokenizer_url: "Salesforce/codet5p-16b"
   max_source_length: 512
   max_prediction_length: 512
-  beam_size: 5
+  trust_remote_code: True
+  device_map: False
 codet5-plus-6B-pretrained:
   huggingface_url: "Salesforce/codet5p-6b"
   tokenizer_url: "Salesforce/codet5p-6b"
   max_source_length: 512
   max_prediction_length: 512
   beam_size: 5
+  trust_remote_code: True
+  device_map: False
 codet5-plus-2B-pretrained:
   huggingface_url: "Salesforce/codet5p-2b"
   tokenizer_url: "Salesforce/codet5p-2b"
   max_source_length: 512
   max_prediction_length: 512
   beam_size: 5
+  trust_remote_code: True
+  device_map: False
 codet5-plus-770M-python-pretrained:
   huggingface_url: "Salesforce/codet5p-770m-py"
   tokenizer_url: "Salesforce/codet5p-770m-py"
   max_source_length: 512
   max_prediction_length: 512
   beam_size: 5
+  trust_remote_code: False
+  device_map: True
 codet5-plus-770M-pretrained:
   huggingface_url: "Salesforce/codet5p-770m"
   tokenizer_url: "Salesforce/codet5p-770m"
   max_source_length: 512
   max_prediction_length: 512
   beam_size: 5
+  trust_remote_code: False
+  device_map: True
 codet5-plus-220M-pretrained:
   huggingface_url: "Salesforce/codet5p-220m"
   tokenizer_url: "Salesforce/codet5p-220m"
   max_source_length: 512
   max_prediction_length: 512
-  beam_size: 5
+  beam_size: 5
+  trust_remote_code: False
+  device_map: True
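
The two new keys per entry control how each checkpoint is loaded from the Hugging Face Hub: trust_remote_code is True for the CodeT5+ 2B/6B/16B checkpoints, which ship custom modeling code, and device_map toggles accelerate-style device placement. A minimal sketch of how such an entry could be consumed, assuming plain PyYAML parsing rather than CodeTF's actual config loader:

import yaml
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

# Illustration only: reads one entry from the inference config and forwards
# the new keys to transformers. This is not CodeTF's loader code.
with open("codetf/configs/inference/codet5.yaml") as f:
    configs = yaml.safe_load(f)

entry = configs["codet5-plus-220M-pretrained"]
tokenizer = AutoTokenizer.from_pretrained(entry["tokenizer_url"])
# device_map is a YAML boolean in most entries; map a truthy value to
# "auto" the way the model loaders in this commit do.
model = AutoModelForSeq2SeqLM.from_pretrained(
    entry["huggingface_url"],
    trust_remote_code=entry["trust_remote_code"],
    device_map="auto" if entry["device_map"] else None,
)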

codetf/configs/training/codet5.yaml
Lines changed: 2 additions & 2 deletions

@@ -3,7 +3,7 @@ hyperparameters:
   evaluation_strategy: "epoch"
   save_strategy: "epoch"
   logging_strategy: "epoch"
-  num_train_epochs: 10
+  num_train_epochs: 1
   auto_find_batch_size: True
   batch_size: 4
   max_steps: 1000
@@ -22,7 +22,7 @@ hyperparameters:
   weight_decay: 0.001
   run_name: "CodeT5-seq2seq-fine-tuned"
   ddp_find_unused_parameters: False
-  fp16: True
+  fp16: False
   bf16: False
   auto_find_batch: True
   num_workers: 4
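
Both changes shorten the default fine-tuning run (one epoch instead of ten) and disable fp16 mixed precision by default. For reference, most of these hyperparameters map closely onto Hugging Face TrainingArguments fields; a sketch under that assumption (the output_dir and exact trainer wiring are illustrative, not CodeTF's trainer code):

from transformers import TrainingArguments

# Mirrors the updated YAML defaults; output_dir is an assumed placeholder.
args = TrainingArguments(
    output_dir="checkpoints",
    evaluation_strategy="epoch",
    save_strategy="epoch",
    logging_strategy="epoch",
    num_train_epochs=1,            # 10 before this commit
    auto_find_batch_size=True,
    per_device_train_batch_size=4,
    max_steps=1000,                # note: a positive max_steps overrides num_train_epochs
    weight_decay=0.001,
    run_name="CodeT5-seq2seq-fine-tuned",
    ddp_find_unused_parameters=False,
    fp16=False,                    # True before this commit
    bf16=False,
)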

codetf/models/bert_models/__init__.py
Lines changed: 11 additions & 1 deletion

@@ -37,7 +37,17 @@ def load_model_from_config(model_class, model_config, load_in_8bit=False, load_i
         raise ValueError("Only one of load_in_8bit or load_in_4bit can be True. Please choose one.")
 
     if weight_sharding:
-        weights_location = hf_hub_download(checkpoint, "pytorch_model.bin")
+        try:
+            # Try to download the single-file weights first
+            weights_location = hf_hub_download(checkpoint, "pytorch_model.bin")
+        except Exception:
+            try:
+                # If that fails, fall back to the sharded-checkpoint index file
+                weights_location = hf_hub_download(checkpoint, "pytorch_model.bin.index.json")
+            except Exception as e:
+                # If both fail, raise an error
+                raise Exception(f"Failed to download weights: {str(e)}")
 
     config = RobertaConfig.from_pretrained(checkpoint)
     with init_empty_weights():
         model = RobertaModel.from_config(config)
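
The change wraps the weight download in a fallback: try the single-file pytorch_model.bin first, and if the repo only publishes sharded weights, fetch the pytorch_model.bin.index.json index instead (accelerate's load_checkpoint_and_dispatch accepts either a state-dict file or a shard index). The same pattern recurs in the other model families below; isolated, it could be written as the loop sketched here (hypothetical helper, not part of this commit):

from huggingface_hub import hf_hub_download

def resolve_weights(checkpoint: str) -> str:
    # Return a local path to either the single-file weights or the
    # sharded-checkpoint index; `checkpoint` is any Hub repo id.
    for filename in ("pytorch_model.bin", "pytorch_model.bin.index.json"):
        try:
            return hf_hub_download(checkpoint, filename)
        except Exception:
            continue
    raise RuntimeError(f"Failed to download weights for {checkpoint}")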

codetf/models/causal_lm_models/__init__.py
Lines changed: 22 additions & 9 deletions

@@ -36,7 +36,17 @@ def load_model_from_config(model_class, model_config, load_in_8bit=False, load_i
         raise ValueError("Only one of load_in_8bit or load_in_4bit can be True. Please choose one.")
 
     if weight_sharding:
-        weights_location = hf_hub_download(checkpoint, "pytorch_model.bin")
+        try:
+            # Try to download the single-file weights first
+            weights_location = hf_hub_download(checkpoint, "pytorch_model.bin")
+        except Exception:
+            try:
+                # If that fails, fall back to the sharded-checkpoint index file
+                weights_location = hf_hub_download(checkpoint, "pytorch_model.bin.index.json")
+            except Exception as e:
+                # If both fail, raise an error
+                raise Exception(f"Failed to download weights: {str(e)}")
 
     config = AutoConfig.from_pretrained(checkpoint)
     with init_empty_weights():
         model = AutoModelForCausalLM.from_config(config)
@@ -49,13 +59,16 @@ def load_model_from_config(model_class, model_config, load_in_8bit=False, load_i
     if load_in_8bit:
         model = AutoModelForCausalLM.from_pretrained(checkpoint,
                                                      load_in_8bit=load_in_8bit,
+                                                     low_cpu_mem_usage=True,
                                                      device_map="auto")
     elif load_in_4bit:
         model = AutoModelForCausalLM.from_pretrained(checkpoint,
                                                      load_in_4bit=load_in_4bit,
+                                                     low_cpu_mem_usage=True,
                                                      device_map="auto")
     else:
         model = AutoModelForCausalLM.from_pretrained(checkpoint,
+                                                     low_cpu_mem_usage=True,
                                                      device_map="auto")
 
 
@@ -67,17 +80,17 @@ def load_model_from_config(model_class, model_config, load_in_8bit=False, load_i
         tokenizer=tokenizer
     )
 
-    def forward(self, sources):
-        encoding = self.tokenizer(sources, return_tensors='pt')
-        input_ids = encoding.input_ids.to(self.device)
-        attention_mask = encoding.attention_mask.to(self.device)
-        generated_ids = self.model.generate(input_ids, attention_mask=attention_mask,
-                                            max_length=self.max_prediction_length)
+    def forward(self, sources, max_length=512):
+        # Move the whole BatchEncoding to the model's device in one call
+        encoding = self.tokenizer(sources, return_tensors='pt').to(self.model.device)
+        generated_ids = self.model.generate(**encoding,
+                                            max_length=max_length)
 
         predictions = self.tokenizer.batch_decode(generated_ids, truncate_before_pattern=[r"\n\n^#", "^'''", "\n\n\n"])
         return predictions
 
-    def predict(self, sources):
+    def predict(self, sources, max_length=512):
         input_for_net = [' '.join(source.strip().split()).replace('\n', ' ') for source in sources]
-        output = self.forward(input_for_net)
+        output = self.forward(input_for_net, max_length=max_length)
         return output
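
With the new signatures, callers cap the generation budget per request instead of relying on the config-level max_prediction_length. A hedged usage sketch, assuming `model` is an already-loaded CodeTF causal-LM wrapper and the prompt is illustrative:

# Assumes `model` is a loaded CodeTF causal-LM model wrapper
# (see load_model_from_config above); the prompt is illustrative.
prompts = ["def bubble_sort(arr):"]
completions = model.predict(prompts, max_length=128)  # caps generation at 128 tokens
print(completions[0])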
codetf/models/seq2seq_models/__init__.py
Lines changed: 58 additions & 22 deletions

@@ -1,11 +1,15 @@
 import sys
 from pathlib import Path
 sys.path.append(str(Path(".").absolute().parent))
-from transformers import RobertaTokenizer
+from transformers import AutoTokenizer
 from codetf.models.base_model import BaseModel
 from transformers import AutoModelForSeq2SeqLM, AutoConfig
 from codetf.common.registry import registry
 from accelerate import Accelerator
+import torch
+from accelerate import init_empty_weights, load_checkpoint_and_dispatch
+from huggingface_hub import hf_hub_download
 
 @registry.register_model("codet5")
 class Seq2SeqModel(BaseModel):
@@ -22,7 +26,7 @@ def __init__(self, model, model_config, tokenizer):
 
     @classmethod
     def init_tokenizer(cls, model):
-        return RobertaTokenizer.from_pretrained(model)
+        return AutoTokenizer.from_pretrained(model)
 
 
     @classmethod
@@ -33,28 +37,62 @@ def load_model_from_config(model_class, model_config, load_in_8bit=False, load_i
         if load_in_8bit and load_in_4bit:
             raise ValueError("Only one of load_in_8bit or load_in_4bit can be True. Please choose one.")
 
+        # This "device" is for the case of CodeT5plus, will be removed in the future
+        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
         if weight_sharding:
-            weights_location = hf_hub_download(checkpoint, "pytorch_model.bin")
+            try:
+                # Try to download the single-file weights first
+                weights_location = hf_hub_download(checkpoint, "pytorch_model.bin")
+            except Exception:
+                try:
+                    # If that fails, fall back to the sharded-checkpoint index file
+                    weights_location = hf_hub_download(checkpoint, "pytorch_model.bin.index.json")
+                except Exception as e:
+                    # If both fail, raise an error
+                    raise Exception(f"Failed to download weights: {str(e)}")
             config = AutoConfig.from_pretrained(checkpoint)
             with init_empty_weights():
                 model = AutoModelForSeq2SeqLM.from_config(config)
 
             model.tie_weights()
             model = load_checkpoint_and_dispatch(
-                model, weights_location, device_map="auto", no_split_module_classes=["GPTJBlock"]
+                model, weights_location, device_map=model_config["device_map"],
+                no_split_module_classes=["GPTJBlock"]
             )
         else:
             if load_in_8bit:
-                model = AutoModelForSeq2SeqLM.from_pretrained(checkpoint,
-                                                              load_in_8bit=load_in_8bit,
-                                                              device_map="auto")
+                if model_config["device_map"]:
+                    model = AutoModelForSeq2SeqLM.from_pretrained(checkpoint,
+                                                                  load_in_8bit=load_in_8bit,
+                                                                  low_cpu_mem_usage=True,
+                                                                  device_map="auto",
+                                                                  trust_remote_code=model_config["trust_remote_code"])
+                else:
+                    model = AutoModelForSeq2SeqLM.from_pretrained(checkpoint,
+                                                                  load_in_8bit=load_in_8bit,
+                                                                  low_cpu_mem_usage=True,
+                                                                  trust_remote_code=model_config["trust_remote_code"])
             elif load_in_4bit:
-                model = AutoModelForSeq2SeqLM.from_pretrained(checkpoint,
-                                                              load_in_4bit=load_in_4bit,
-                                                              device_map="auto")
+                if model_config["device_map"]:
+                    model = AutoModelForSeq2SeqLM.from_pretrained(checkpoint,
+                                                                  load_in_4bit=load_in_4bit,
+                                                                  low_cpu_mem_usage=True,
+                                                                  device_map="auto",
+                                                                  trust_remote_code=model_config["trust_remote_code"])
+                else:
+                    model = AutoModelForSeq2SeqLM.from_pretrained(checkpoint,
+                                                                  load_in_4bit=load_in_4bit,
+                                                                  low_cpu_mem_usage=True,
+                                                                  trust_remote_code=model_config["trust_remote_code"])
             else:
-                model = AutoModelForSeq2SeqLM.from_pretrained(checkpoint,
-                                                              device_map="auto")
+                if model_config["device_map"]:
+                    model = AutoModelForSeq2SeqLM.from_pretrained(checkpoint,
+                                                                  low_cpu_mem_usage=True,
+                                                                  device_map=model_config["device_map"],
+                                                                  trust_remote_code=model_config["trust_remote_code"])
+                else:
+                    model = AutoModelForSeq2SeqLM.from_pretrained(checkpoint,
+                                                                  low_cpu_mem_usage=True,
+                                                                  trust_remote_code=model_config["trust_remote_code"]).to(device)
 
 
         tokenizer = model_class.init_tokenizer(model_config.get("tokenizer_url"))
@@ -66,22 +104,20 @@ def load_model_from_config(model_class, model_config, load_in_8bit=False, load_i
             tokenizer=tokenizer
         )
 
-    def forward(self, sources):
-        encoding = self.tokenizer(sources, return_tensors='pt')
-        input_ids = encoding.input_ids.to(self.device)
-        attention_mask = encoding.attention_mask.to(self.device)
-        generated_ids = self.model.generate(input_ids, attention_mask=attention_mask,
-                                            max_length=self.max_prediction_length,
-                                            num_beams=self.beam_size)
+    def forward(self, sources, max_length=512, beam_size=5):
+        encoding = self.tokenizer(sources, return_tensors='pt').to(self.model.device)
+        # Seed the decoder from the input ids (needed for CodeT5+ checkpoints)
+        encoding['decoder_input_ids'] = encoding['input_ids'].clone()
+        generated_ids = self.model.generate(**encoding,
+                                            max_length=max_length,
+                                            num_beams=beam_size)
 
         predictions = self.tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
         return predictions
 
 
-    def predict(self, sources):
+    def predict(self, sources, max_length=512, beam_size=5):
 
         input_for_net = [' '.join(source.strip().split()).replace('\n', ' ') for source in sources]
-        # if self.task in ["sum", "translate", "nl2code", "refine"]:
-        output = self.forward(input_for_net)
+        output = self.forward(input_for_net, max_length, beam_size)
 
         return output
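
On the seq2seq side, predict likewise takes max_length and beam_size per call instead of reading self.max_prediction_length and self.beam_size, and forward now seeds decoder_input_ids from the input ids, matching the usage shown on the CodeT5+ model cards. A usage sketch under the same assumption that `model` is a loaded Seq2SeqModel (e.g. a CodeT5 summarization checkpoint):

# Assumes `model` is a loaded Seq2SeqModel wrapper; the input is illustrative.
code = "def add(a, b): return a + b"
# Narrower beam and shorter budget than the old config-level defaults
summaries = model.predict([code], max_length=64, beam_size=3)
print(summaries[0])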
