
Commit 5b39d49

Add FP16 function to llama2 (#227)
1 parent: a5e679f

File tree

  • natural_language_processing/text_generation/llama2

1 file changed: 5 additions, 1 deletion


natural_language_processing/text_generation/llama2/run.py

Lines changed: 5 additions & 1 deletion
@@ -4,7 +4,7 @@
 from transformers import LlamaForCausalLM, AutoTokenizer
 
 
-def run_pytorch(model_name, batch_size, num_runs, timeout, dataset_path):
+def run_pytorch(model_name, batch_size, num_runs, timeout, dataset_path, use_torch_fp16=False):
     def run_single_pass(pytorch_runner, _dataset):
         input_tensor = tokenizer.encode(_dataset.get_input_string(), return_tensors="pt")
         input_tensor = torch.cat([input_tensor for _ in range(batch_size)], 0)
@@ -20,6 +20,8 @@ def run_single_pass(pytorch_runner, _dataset):
 
     model = LlamaForCausalLM.from_pretrained(model_name, torchscript=True)
     model.eval()
+    if use_torch_fp16:
+        model = model.half()
     model.generate = apply_compile(model.generate)
 
     tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side='left')
@@ -37,6 +39,8 @@ def run_single_pass(pytorch_runner, _dataset):
 def run_pytorch_fp32(model_name, batch_size, num_runs, timeout, dataset_path, **kwargs):
     return run_pytorch(model_name, batch_size, num_runs, timeout, dataset_path)
 
+def run_pytorch_fp16(model_name, batch_size, num_runs, timeout, dataset_path, **kwargs):
+    return run_pytorch(model_name, batch_size, num_runs, timeout, dataset_path, use_torch_fp16=True)
 
 def main():
     from utils.helpers import DefaultArgParser
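
The functional change is the optional model.half() call: when use_torch_fp16 is set, every parameter and buffer of the Llama model is cast to torch.float16 before generate is wrapped by apply_compile, and the new run_pytorch_fp16 entry point simply forwards to run_pytorch with that flag enabled. Below is a minimal, self-contained sketch of the same idea outside the repo's benchmark harness; the checkpoint name and prompt are placeholders, not taken from this commit.

import torch
from transformers import LlamaForCausalLM, AutoTokenizer

model_name = "meta-llama/Llama-2-7b-hf"   # placeholder checkpoint, any Llama causal LM works
device = "cuda" if torch.cuda.is_available() else "cpu"

tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side="left")
model = LlamaForCausalLM.from_pretrained(model_name, torchscript=True)
model.eval()

# Same switch the commit adds: .half() casts all parameters and buffers to
# torch.float16. Half-precision matmuls are effectively a GPU feature, so this
# sketch only enables it when CUDA is available.
use_torch_fp16 = device == "cuda"
if use_torch_fp16:
    model = model.half()
model = model.to(device)

# Token IDs are integers, so the inputs need no dtype change for FP16.
input_ids = tokenizer.encode("The capital of France is", return_tensors="pt").to(device)
with torch.no_grad():
    output = model.generate(input_ids, max_new_tokens=16)
print(tokenizer.decode(output[0], skip_special_tokens=True))

Casting the weights with .half() roughly halves their memory footprint, and activations during generation follow the parameter dtype, which is why the patch does not need to touch run_single_pass or the tokenized inputs.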
