diff --git a/speech_recognition/whisper/run.py b/speech_recognition/whisper/run.py
index fca19ccb..afd34abc 100644
--- a/speech_recognition/whisper/run.py
+++ b/speech_recognition/whisper/run.py
@@ -18,7 +18,7 @@
     sys.exit(1)
 
 
-def run_pytorch_fp32(model_name, num_runs, timeout, **kwargs):
+def run_pytorch(model_name, num_runs, timeout, use_torch_fp16=False):
     import os
     import sys
     import torch
@@ -32,6 +32,10 @@ def run_pytorch_fp32(model_name, num_runs, timeout, **kwargs):
     from speech_recognition.whisper.whisper.whisper.transcribe import transcribe
     model = load_model(model_name)
     model.eval()
+    if use_torch_fp16:
+        model = model.half()
+        model._encoder.half()
+        model._decoder.half()
 
     def single_pass_pytorch(_runner, _librispeech):
         array = _librispeech.get_input_array()
@@ -40,8 +44,10 @@ def single_pass_pytorch(_runner, _librispeech):
             _runner.run(batch_size * array.shape[0], audio)["text"].lstrip().replace(".", "").upper()
         )
 
+    decode_options = {"fp16": use_torch_fp16}
+
     def transcribe_wrapper(audio):
-        return transcribe(model, audio, no_speech_threshold=1.0, verbose=None)
+        return transcribe(model, audio, no_speech_threshold=1.0, verbose=None, **decode_options)
 
     runner = PyTorchRunnerV2(transcribe_wrapper, throughput_only=True)
     librispeech = LibriSpeech()
@@ -50,6 +56,16 @@ def transcribe_wrapper(audio):
     return run_model(single_pass_pytorch, runner, librispeech, batch_size, num_runs, timeout)
 
 
+def run_pytorch_fp32(model_name, num_runs, timeout, **kwargs):
+    """Run the Whisper PyTorch pipeline in FP32; extra parsed CLI options (e.g. precision) are ignored."""
+    return run_pytorch(model_name, num_runs, timeout, use_torch_fp16=False)
+
+
+def run_pytorch_fp16(model_name, num_runs, timeout, **kwargs):
+    """Run the Whisper PyTorch pipeline in FP16; extra parsed CLI options (e.g. precision) are ignored."""
+    return run_pytorch(model_name, num_runs, timeout, use_torch_fp16=True)
+
+
 def run_pytorch_cuda(model_name, num_runs, timeout, **kwargs):
     import os
     import sys
@@ -89,8 +105,12 @@ def transcribe_wrapper(audio):
     whisper_variants = whisper_variants + [f"{name}.en" for name in whisper_variants[:4]]
     parser = DefaultArgParser(["pytorch"])
     parser.require_model_name(whisper_variants)
+    parser.add_argument("-p", "--precision", type=str, choices=["fp32", "fp16"], required=True)
 
+    args = vars(parser.parse())
     if torch.cuda.is_available():
-        run_pytorch_cuda(**vars(parser.parse()))
-    else:
-        run_pytorch_fp32(**vars(parser.parse()))
+        run_pytorch_cuda(**args)
+    elif args["precision"] == "fp32":
+        run_pytorch_fp32(**args)
+    elif args["precision"] == "fp16":
+        run_pytorch_fp16(**args)