@@ -4,7 +4,7 @@
 from transformers import LlamaForCausalLM, AutoTokenizer
 
 
-def run_pytorch(model_name, batch_size, num_runs, timeout, dataset_path):
+def run_pytorch(model_name, batch_size, num_runs, timeout, dataset_path, use_torch_fp16=False):
     def run_single_pass(pytorch_runner, _dataset):
         input_tensor = tokenizer.encode(_dataset.get_input_string(), return_tensors="pt")
         input_tensor = torch.cat([input_tensor for _ in range(batch_size)], 0)
@@ -20,6 +20,8 @@ def run_single_pass(pytorch_runner, _dataset):
 
     model = LlamaForCausalLM.from_pretrained(model_name, torchscript=True)
     model.eval()
+    if use_torch_fp16:
+        model = model.half()
     model.generate = apply_compile(model.generate)
 
     tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side='left')
@@ -37,6 +39,8 @@ def run_single_pass(pytorch_runner, _dataset):
 def run_pytorch_fp32(model_name, batch_size, num_runs, timeout, dataset_path, **kwargs):
     return run_pytorch(model_name, batch_size, num_runs, timeout, dataset_path)
 
+def run_pytorch_fp16(model_name, batch_size, num_runs, timeout, dataset_path, **kwargs):
+    return run_pytorch(model_name, batch_size, num_runs, timeout, dataset_path, use_torch_fp16=True)
 
 def main():
     from utils.helpers import DefaultArgParser
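
For context, a minimal standalone sketch of what the new use_torch_fp16 path does at load time. The checkpoint name below is a placeholder, the script's apply_compile wrapper and benchmark harness are omitted, and running fp16 inference on CPU assumes a recent PyTorch build; on older builds, move the model to a GPU before generating.

import torch
from transformers import LlamaForCausalLM, AutoTokenizer

# Hypothetical checkpoint for illustration; any Llama-compatible model works.
model_name = "meta-llama/Llama-2-7b-hf"

model = LlamaForCausalLM.from_pretrained(model_name, torchscript=True)
model.eval()
# This is the cast the new flag enables: every fp32 parameter and buffer is
# converted to fp16, halving weight memory at a small accuracy cost.
model = model.half()

tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side='left')
input_ids = tokenizer.encode("The capital of France is", return_tensors="pt")
with torch.no_grad():
    output = model.generate(input_ids, max_new_tokens=8)
print(tokenizer.decode(output[0], skip_special_tokens=True))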