1717 f"\033 [0m" )
1818 sys .exit (1 )
1919
20-
21- def run_pytorch_fp32 (model_name , num_runs , timeout , ** kwargs ):
20+ def run_pytorch (model_name , num_runs , timeout , use_torch_fp16 = False ):
2221 import os
2322 import sys
2423 import torch
@@ -32,6 +31,10 @@ def run_pytorch_fp32(model_name, num_runs, timeout, **kwargs):
3231 from speech_recognition .whisper .whisper .whisper .transcribe import transcribe
3332 model = load_model (model_name )
3433 model .eval ()
34+ if use_torch_fp16 :
35+ model = model .half ()
36+ model ._encoder .half ()
37+ model ._decoder .half ()
3538
3639 def single_pass_pytorch (_runner , _librispeech ):
3740 array = _librispeech .get_input_array ()
@@ -40,15 +43,22 @@ def single_pass_pytorch(_runner, _librispeech):
4043 _runner .run (batch_size * array .shape [0 ], audio )["text" ].lstrip ().replace ("." , "" ).upper ()
4144 )
4245
46+ decode_options = {"fp16" : use_torch_fp16 }
47+
4348 def transcribe_wrapper (audio ):
44- return transcribe (model , audio , no_speech_threshold = 1.0 , verbose = None )
49+ return transcribe (model , audio , no_speech_threshold = 1.0 , verbose = None , ** decode_options )
4550
4651 runner = PyTorchRunnerV2 (transcribe_wrapper , throughput_only = True )
4752 librispeech = LibriSpeech ()
4853 print_warning_message ("Sampling rate Whisper operates at is 16,000 Hz, therefore throughput values below can be "
4954 "divided by 16,000 to derive 'seconds of processed audio per second'" )
5055 return run_model (single_pass_pytorch , runner , librispeech , batch_size , num_runs , timeout )
5156
def run_pytorch_fp32(model_name, num_runs, timeout, **kwargs):
    """Run the Whisper PyTorch benchmark in full FP32 precision.

    Thin wrapper over the shared run_pytorch driver with FP16 disabled.

    :param model_name: str, Whisper model variant to load (passed to load_model)
    :param num_runs: int, number of benchmark runs to execute
    :param timeout: benchmark timeout (forwarded to run_model)
    :param kwargs: ignored; accepted so surplus CLI-parsed arguments
        (e.g. batch_size) don't break the call, matching run_pytorch_cuda's
        **kwargs-tolerant signature and this function's original interface
    :return: whatever run_model returns for the completed benchmark
    """
    return run_pytorch(model_name, num_runs, timeout, use_torch_fp16=False)
59+
def run_pytorch_fp16(model_name, num_runs, timeout, **kwargs):
    """Run the Whisper PyTorch benchmark in FP16 (half) precision.

    Thin wrapper over the shared run_pytorch driver with FP16 enabled
    (the model is converted via .half() and transcribe receives fp16=True).

    :param model_name: str, Whisper model variant to load (passed to load_model)
    :param num_runs: int, number of benchmark runs to execute
    :param timeout: benchmark timeout (forwarded to run_model)
    :param kwargs: ignored; accepted so surplus CLI-parsed arguments don't
        break the call, keeping the signature consistent with
        run_pytorch_fp32 and run_pytorch_cuda
    :return: whatever run_model returns for the completed benchmark
    """
    return run_pytorch(model_name, num_runs, timeout, use_torch_fp16=True)
5262
5363def run_pytorch_cuda (model_name , num_runs , timeout , ** kwargs ):
5464 import os
0 commit comments