Skip to content
This repository was archived by the owner on May 1, 2025. It is now read-only.

Commit 8bdca54

Browse files
author
Nghi Bui
committed
samples for new version
1 parent 3f5da9c commit 8bdca54

File tree

6 files changed

+39
-37
lines changed

6 files changed

+39
-37
lines changed

codetf/models/causal_lm_models/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -81,12 +81,13 @@ def load_model_from_config(model_class, model_config, load_in_8bit=False, load_i
8181
)
8282

8383
def forward(self, sources, max_length=512):
84-
encoding = self.tokenizer(sources, return_tensors='pt').to(self.model.device)
84+
encoding = self.tokenizer(sources, return_tensors='pt').to(self.device)
8585
# input_ids = encoding.input_ids.to(self.device)
8686
# attention_mask = encoding.attention_mask.to(self.device)
8787
generated_ids = self.model.generate(**encoding,
8888
max_length=max_length)
8989

90+
print(generated_ids)
9091
predictions = self.tokenizer.batch_decode(generated_ids, truncate_before_pattern=[r"\n\n^#", "^'''", "\n\n\n"])
9192
return predictions
9293

requirements.txt

Lines changed: 10 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,24 +1,23 @@
11
# Automatically generated by https://github.com/damnever/pigar.
22

33
accelerate==0.20.3
4-
datasets==2.12.0
5-
evaluate==0.4.0
6-
huggingface-hub==0.14.1
4+
datasets==2.13.1
5+
huggingface-hub==0.15.1
76
iopath==0.1.10
87
nltk==3.8.1
9-
numpy==1.21.6
8+
numpy==1.25.0
109
omegaconf==2.3.0
11-
pandas==1.3.5
12-
peft==0.4.0.dev0
10+
pandas==2.0.2
11+
peft==0.3.0
1312
pyparsing==3.0.7
1413
PyYAML==6.0
15-
requests==2.31.0
14+
requests==2.27.1
1615
rouge-score==0.1.2
1716
sacrebleu==2.3.1
18-
salesforce-codetf==1.0.1.1
19-
scikit-learn==1.0.2
20-
torch==1.13.1
21-
torchvision==0.14.1
17+
scikit-learn==1.2.2
18+
torch==2.0.1
19+
torchvision==0.15.2
2220
tqdm==4.63.0
2321
transformers==4.30.2
2422
tree-sitter==0.20.1
23+
bitsandbytes==0.39.1

setup.py

Lines changed: 21 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -2,25 +2,27 @@
22
import platform
33

44
install_requires = [
5-
"datasets==2.12.0",
6-
"evaluate==0.4.0",
7-
"huggingface-hub==0.14.1",
8-
"iopath==0.1.10",
9-
"nltk==3.8.1",
10-
"numpy==1.22.0",
11-
"omegaconf==2.3.0",
12-
"pandas==1.3.5",
13-
"pyparsing==3.0.7",
14-
"PyYAML==6.0",
15-
"requests==2.31.0",
16-
"rouge-score==0.1.2",
17-
"sacrebleu==2.3.1",
18-
"scikit-learn==1.0.2",
19-
"torch==1.13.1",
20-
"torchvision==0.14.1",
21-
"tqdm==4.63.0",
22-
"tree-sitter==0.20.1",
23-
"bitsandbytes==0.39.0"
5+
"accelerate==0.20.3",
6+
"datasets==2.13.1",
7+
"huggingface-hub==0.15.1",
8+
"iopath==0.1.10",
9+
"nltk==3.8.1",
10+
"numpy==1.25.0",
11+
"omegaconf==2.3.0",
12+
"pandas==2.0.2",
13+
"peft==0.3.0",
14+
"pyparsing==3.0.7",
15+
"PyYAML==6.0",
16+
"requests==2.27.1",
17+
"rouge-score==0.1.2",
18+
"sacrebleu==2.3.1",
19+
"scikit-learn==1.2.2",
20+
"torch==2.0.1",
21+
"torchvision==0.15.2",
22+
"tqdm==4.63.0",
23+
"transformers==4.30.2",
24+
"tree-sitter==0.20.1",
25+
"bitsandbytes==0.39.1"
2426
]
2527

2628
DEPENDENCY_LINKS = []

test_inference/test_codegen_nl2code.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,8 @@
44
from codetf.models import load_model_pipeline
55

66
code_generation_model = load_model_pipeline(model_name="causallm", task="pretrained",
7-
model_type="codegen-350M-mono", is_eval=True,
7+
model_type="codegen-2B-mono", is_eval=True,
88
load_in_8bit=True, load_in_4bit=False, weight_sharding=False)
99

10-
result = code_generation_model.predict(["def print_hello_world():"])
10+
result = code_generation_model.predict(["# this function prints hello world"])
1111
print(result)

test_inference/test_codet5plus_nl2code.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,8 @@
44
from codetf.models import load_model_pipeline
55

66
code_generation_model = load_model_pipeline(model_name="codet5", task="pretrained",
7-
model_type="plus-770M-python", is_eval=True,
8-
load_in_8bit=True, load_in_4bit=False, weight_sharding=False)
7+
model_type="plus-2B", is_eval=True,
8+
load_in_8bit=False, load_in_4bit=False, weight_sharding=True)
99

10-
result = code_generation_model.predict(["def print_hello_world():"])
10+
result = code_generation_model.predict(["def print_hello_world():"], max_length=15)
1111
print(result)

test_inference/test_starcoder_nl2code.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
model_type="starcoder-15.5B", is_eval=True,
99
load_in_8bit=True, weight_sharding=False)
1010

11-
prompts = "def print_hello_world():"
11+
prompts = "# this function prints hello world"
1212
code_snippets = model.predict([prompts])
1313

1414
print(code_snippets)

0 commit comments

Comments (0)