From d8375269b806f4c58dbfe6eeb06e81071a2c0017 Mon Sep 17 00:00:00 2001
From: Karol Kontny
Date: Wed, 27 Mar 2024 13:51:19 +0100
Subject: [PATCH 1/3] Adding fp16 mode whisper

---
 speech_recognition/whisper/run.py | 16 +++++++++++++---
 1 file changed, 13 insertions(+), 3 deletions(-)

diff --git a/speech_recognition/whisper/run.py b/speech_recognition/whisper/run.py
index fca19ccb..3f57345e 100644
--- a/speech_recognition/whisper/run.py
+++ b/speech_recognition/whisper/run.py
@@ -17,8 +17,7 @@
         f"\033[0m")
     sys.exit(1)
 
-
-def run_pytorch_fp32(model_name, num_runs, timeout, **kwargs):
+def run_pytorch(model_name, num_runs, timeout, use_torch_fp16=False):
     import os
     import sys
     import torch
@@ -32,6 +31,10 @@ def run_pytorch_fp32(model_name, num_runs, timeout, **kwargs):
     from speech_recognition.whisper.whisper.whisper.transcribe import transcribe
     model = load_model(model_name)
     model.eval()
+    if use_torch_fp16:
+        model = model.half()
+        model._encoder.half()
+        model._decoder.half()
 
     def single_pass_pytorch(_runner, _librispeech):
         array = _librispeech.get_input_array()
@@ -40,8 +43,10 @@ def single_pass_pytorch(_runner, _librispeech):
             _runner.run(batch_size * array.shape[0], audio)["text"].lstrip().replace(".", "").upper()
         )
 
+    decode_options = {"fp16": use_torch_fp16}
+
     def transcribe_wrapper(audio):
-        return transcribe(model, audio, no_speech_threshold=1.0, verbose=None)
+        return transcribe(model, audio, no_speech_threshold=1.0, verbose=None, **decode_options)
 
     runner = PyTorchRunnerV2(transcribe_wrapper, throughput_only=True)
     librispeech = LibriSpeech()
@@ -49,6 +54,11 @@ def transcribe_wrapper(audio):
           "divided by 16,000 to derive 'seconds of processed audio per second'")
     return run_model(single_pass_pytorch, runner, librispeech, batch_size, num_runs, timeout)
 
+def run_pytorch_fp32(model_name, num_runs, timeout):
+    return run_pytorch(model_name, num_runs, timeout, use_torch_fp16=False)
+
+def run_pytorch_fp16(model_name, num_runs, timeout):
+    return run_pytorch(model_name, num_runs, timeout, use_torch_fp16=True)
 
 def run_pytorch_cuda(model_name, num_runs, timeout, **kwargs):
     import os

From 95cef1aa8c1a199a2e991f310403a75cc9746ed5 Mon Sep 17 00:00:00 2001
From: Karol Kontny
Date: Tue, 7 May 2024 16:26:55 +0200
Subject: [PATCH 2/3] Fix run.py script

---
 speech_recognition/whisper/run.py | 10 +++++++---
 1 file changed, 7 insertions(+), 3 deletions(-)

diff --git a/speech_recognition/whisper/run.py b/speech_recognition/whisper/run.py
index 3f57345e..0ccd72b8 100644
--- a/speech_recognition/whisper/run.py
+++ b/speech_recognition/whisper/run.py
@@ -99,8 +99,12 @@ def transcribe_wrapper(audio):
 
     whisper_variants = whisper_variants + [f"{name}.en" for name in whisper_variants[:4]]
     parser = DefaultArgParser(["pytorch"])
     parser.require_model_name(whisper_variants)
+    parser.add_argument("-p", "--precision", type=str, choices=["fp32", "fp16"], required=True)
+    args = vars(parser.parse())
     if torch.cuda.is_available():
-        run_pytorch_cuda(**vars(parser.parse()))
-    else:
-        run_pytorch_fp32(**vars(parser.parse()))
+        run_pytorch_cuda(**args)
+    elif args["precision"] == "fp32":
+        run_pytorch_fp32(**args)
+    elif args["precision"] == "fp16":
+        run_pytorch_fp16(**args)

From eeeee1a27318e55bf62291d24105f62f97b93249 Mon Sep 17 00:00:00 2001
From: Karol Kontny
Date: Tue, 7 May 2024 16:33:28 +0200
Subject: [PATCH 3/3] Add blank lines

---
 speech_recognition/whisper/run.py | 4 ++++
 1 file changed, 4 insertions(+)

diff --git a/speech_recognition/whisper/run.py b/speech_recognition/whisper/run.py
index 0ccd72b8..afd34abc 100644
--- a/speech_recognition/whisper/run.py
+++ b/speech_recognition/whisper/run.py
@@ -17,6 +17,7 @@
         f"\033[0m")
     sys.exit(1)
 
+
 def run_pytorch(model_name, num_runs, timeout, use_torch_fp16=False):
     import os
     import sys
     import torch
@@ -54,12 +55,15 @@ def transcribe_wrapper(audio):
           "divided by 16,000 to derive 'seconds of processed audio per second'")
     return run_model(single_pass_pytorch, runner, librispeech, batch_size, num_runs, timeout)
 
+
 def run_pytorch_fp32(model_name, num_runs, timeout):
     return run_pytorch(model_name, num_runs, timeout, use_torch_fp16=False)
 
+
 def run_pytorch_fp16(model_name, num_runs, timeout):
     return run_pytorch(model_name, num_runs, timeout, use_torch_fp16=True)
 
+
 def run_pytorch_cuda(model_name, num_runs, timeout, **kwargs):
     import os
     import sys
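
In summary: patch 1 generalizes run_pytorch_fp32() into run_pytorch(), which, when use_torch_fp16 is set, casts the model to half precision and forwards {"fp16": use_torch_fp16} to transcribe() as a decode option; patch 2 adds a required -p/--precision argument (fp32 or fp16) and dispatches to run_pytorch_fp32() or run_pytorch_fp16() when CUDA is unavailable; patch 3 only restores PEP 8 blank lines.

For orientation, the sketch below exercises the same fp16 decode option against upstream openai-whisper rather than the vendored fork that run.py imports; it is illustrative only and not part of the patches. "tiny.en" and "sample.wav" are placeholders, and fp16 is gated on CUDA availability because upstream whisper warns and falls back to fp32 when fp16 decoding is requested on CPU.

    # Illustrative only: upstream openai-whisper (pip install openai-whisper),
    # not the vendored fork used by run.py. Model size and audio path are placeholders.
    import torch
    import whisper

    use_fp16 = torch.cuda.is_available()  # upstream whisper falls back to fp32 on CPU
    model = whisper.load_model("tiny.en")
    model.eval()
    if use_fp16:
        model = model.half()  # cast weights to float16, as patch 1 does

    # "fp16" is a DecodingOptions field; transcribe() forwards it via **decode_options,
    # mirroring decode_options = {"fp16": use_torch_fp16} in the patched run.py.
    result = model.transcribe("sample.wav",
                              no_speech_threshold=1.0,
                              verbose=None,
                              fp16=use_fp16)
    print(result["text"].strip())

On the benchmark side, precision is now selected explicitly on the command line, e.g. python run.py ... -p fp16, alongside the model-name argument that require_model_name enforces.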