Commit 650ef69

Whisper HF (#241)
1 parent 67dc024 commit 650ef69

File tree

7 files changed: +97 -31 lines changed

.github/workflows/test.yml

Lines changed: 16 additions & 16 deletions
@@ -81,24 +81,24 @@ jobs:
         tar -xf aio_objdet_dataset.tar.gz > /dev/null

         wget $S3_URL_RESNET_50_V15_TF_FP32 > /dev/null 2>&1
-        python3 computer_vision/classification/resnet_50_v15/run.py -m resnet_50_v15_tf_fp32.pb -p fp32 -f tf --timeout=60
+        IGNORE_DATASET_LIMITS=1 python3 computer_vision/classification/resnet_50_v15/run.py -m resnet_50_v15_tf_fp32.pb -p fp32 -f tf --timeout=60

-        python3 computer_vision/classification/mobilenet_v2/run.py -p fp32 -f pytorch --timeout=60
+        IGNORE_DATASET_LIMITS=1 python3 computer_vision/classification/mobilenet_v2/run.py -p fp32 -f pytorch --timeout=60

         wget https://github.com/ultralytics/assets/releases/download/v0.0.0/yolov8n.pt > /dev/null 2>&1
-        python3 computer_vision/object_detection/yolo_v8/run.py -m yolov8n.pt -f pytorch -p fp32 --timeout=60
+        IGNORE_DATASET_LIMITS=1 python3 computer_vision/object_detection/yolo_v8/run.py -m yolov8n.pt -f pytorch -p fp32 --timeout=60

         python3 speech_recognition/whisper/run.py -m small.en

         wget $S3_URL_SSD_INCEPTION_V2_TF_FP32 > /dev/null 2>&1
-        python3 computer_vision/object_detection/ssd_inception_v2/run.py -m ssd_inception_v2_tf_fp32.pb -p fp32 --timeout=60
+        IGNORE_DATASET_LIMITS=1 python3 computer_vision/object_detection/ssd_inception_v2/run.py -m ssd_inception_v2_tf_fp32.pb -p fp32 --timeout=60

         wget https://zenodo.org/records/4735647/files/resnet50_v1.onnx > /dev/null 2>&1
-        python3 computer_vision/classification/resnet_50_v1/run.py -m resnet50_v1.onnx -p fp32 -f ort
+        IGNORE_DATASET_LIMITS=1 python3 computer_vision/classification/resnet_50_v1/run.py -m resnet50_v1.onnx -p fp32 -f ort

         wget https://s3.amazonaws.com/onnx-model-zoo/vgg/vgg16/vgg16.tar.gz > /dev/null 2>&1
         tar -xf vgg16.tar.gz > /dev/null
-        python3 computer_vision/classification/vgg_16/run.py -m vgg16/vgg16.onnx -p fp32 -f ort
+        IGNORE_DATASET_LIMITS=1 python3 computer_vision/classification/vgg_16/run.py -m vgg16/vgg16.onnx -p fp32 -f ort

   test_arm64:
     runs-on: self-hosted
@@ -145,24 +145,24 @@ jobs:
         tar -xf aio_objdet_dataset.tar.gz > /dev/null

         wget $S3_URL_RESNET_50_V15_TF_FP32 > /dev/null 2>&1
-        python3 computer_vision/classification/resnet_50_v15/run.py -m resnet_50_v15_tf_fp32.pb -p fp32 -f tf --timeout=60
+        IGNORE_DATASET_LIMITS=1 python3 computer_vision/classification/resnet_50_v15/run.py -m resnet_50_v15_tf_fp32.pb -p fp32 -f tf --timeout=60

-        python3 computer_vision/classification/mobilenet_v2/run.py -p fp32 -f pytorch --timeout=60
+        IGNORE_DATASET_LIMITS=1 python3 computer_vision/classification/mobilenet_v2/run.py -p fp32 -f pytorch --timeout=60

         wget https://github.com/ultralytics/assets/releases/download/v0.0.0/yolov8n.pt > /dev/null 2>&1
-        python3 computer_vision/object_detection/yolo_v8/run.py -m yolov8n.pt -f pytorch -p fp32 --timeout=60
+        IGNORE_DATASET_LIMITS=1 python3 computer_vision/object_detection/yolo_v8/run.py -m yolov8n.pt -f pytorch -p fp32 --timeout=60

         python3 speech_recognition/whisper/run.py -m small.en

         wget $S3_URL_SSD_INCEPTION_V2_TF_FP32 > /dev/null 2>&1
-        python3 computer_vision/object_detection/ssd_inception_v2/run.py -m ssd_inception_v2_tf_fp32.pb -p fp32 --timeout=60
+        IGNORE_DATASET_LIMITS=1 python3 computer_vision/object_detection/ssd_inception_v2/run.py -m ssd_inception_v2_tf_fp32.pb -p fp32 --timeout=60

         wget https://zenodo.org/records/4735647/files/resnet50_v1.onnx > /dev/null 2>&1
-        python3 computer_vision/classification/resnet_50_v1/run.py -m resnet50_v1.onnx -p fp32 -f ort
+        IGNORE_DATASET_LIMITS=1 python3 computer_vision/classification/resnet_50_v1/run.py -m resnet50_v1.onnx -p fp32 -f ort

         wget https://s3.amazonaws.com/onnx-model-zoo/vgg/vgg16/vgg16.tar.gz > /dev/null 2>&1
         tar -xf vgg16.tar.gz > /dev/null
-        python3 computer_vision/classification/vgg_16/run.py -m vgg16/vgg16.onnx -p fp32 -f ort
+        IGNORE_DATASET_LIMITS=1 python3 computer_vision/classification/vgg_16/run.py -m vgg16/vgg16.onnx -p fp32 -f ort

   test_pytorch_arm64_sh:
     runs-on: self-hosted
@@ -260,10 +260,10 @@ jobs:

         AIO_IMPLICIT_FP16_TRANSFORM_FILTER=".*" python3 speech_recognition/whisper/run.py -m tiny.en

-        python3 computer_vision/classification/mobilenet_v2/run.py -p fp32 -f pytorch --timeout=60
+        IGNORE_DATASET_LIMITS=1 python3 computer_vision/classification/mobilenet_v2/run.py -p fp32 -f pytorch --timeout=60

         wget https://github.com/ultralytics/assets/releases/download/v0.0.0/yolov8l.pt > /dev/null 2>&1
-        AIO_IMPLICIT_FP16_TRANSFORM_FILTER=".*" python3 computer_vision/object_detection/yolo_v8/run.py -m yolov8l.pt -p fp32 -f pytorch
+        IGNORE_DATASET_LIMITS=1 AIO_IMPLICIT_FP16_TRANSFORM_FILTER=".*" python3 computer_vision/object_detection/yolo_v8/run.py -m yolov8l.pt -p fp32 -f pytorch

         wget -O bert_large_mlperf.pt https://zenodo.org/records/3733896/files/model.pytorch?download=1 > /dev/null 2>&1
         AIO_IMPLICIT_FP16_TRANSFORM_FILTER=".*" python3 natural_language_processing/extractive_question_answering/bert_large/run_mlperf.py -m bert_large_mlperf.pt -p fp32 -f pytorch
@@ -346,8 +346,8 @@ jobs:
         tar -xvf aio_objdet_dataset.tar.gz > /dev/null

         wget https://zenodo.org/records/4735647/files/resnet50_v1.onnx > /dev/null 2>&1
-        AIO_IMPLICIT_FP16_TRANSFORM_FILTER=".*" python3 computer_vision/classification/resnet_50_v1/run.py -m resnet50_v1.onnx -p fp32 -f ort
+        IGNORE_DATASET_LIMITS=1 AIO_IMPLICIT_FP16_TRANSFORM_FILTER=".*" python3 computer_vision/classification/resnet_50_v1/run.py -m resnet50_v1.onnx -p fp32 -f ort

         wget https://s3.amazonaws.com/onnx-model-zoo/vgg/vgg16/vgg16.tar.gz > /dev/null 2>&1
         tar -xf vgg16.tar.gz > /dev/null
-        AIO_IMPLICIT_FP16_TRANSFORM_FILTER=".*" python3 computer_vision/classification/vgg_16/run.py -m vgg16/vgg16.onnx -p fp32 -f ort
+        IGNORE_DATASET_LIMITS=1 AIO_IMPLICIT_FP16_TRANSFORM_FILTER=".*" python3 computer_vision/classification/vgg_16/run.py -m vgg16/vgg16.onnx -p fp32 -f ort
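Note: the only functional change in this workflow is the IGNORE_DATASET_LIMITS=1 prefix on the benchmark invocations. The variable is consumed by the dataset helpers (see utils/speech_recognition/libri_speech_v2.py below): when a run exhausts the dataset, the helper wraps back to the first sample instead of failing, so timeout-bounded CI runs are no longer capped by dataset size. A minimal sketch of the pattern, using an illustrative DatasetStub class rather than the repo's actual code:

    import os

    class DatasetStub:
        """Illustrative stand-in for the repo's dataset helpers."""
        def __init__(self, samples):
            self._samples = samples
            self._idx = 0

        def get_input_array(self):
            try:
                sample = self._samples[self._idx]
                self._idx += 1
                return sample
            except IndexError:
                if os.environ.get("IGNORE_DATASET_LIMITS") == "1":
                    self._idx = 0  # wrap around and keep feeding inputs
                    return self.get_input_array()
                raise  # default: the benchmark stops at the dataset boundary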

.gitignore

Lines changed: 2 additions & 0 deletions
@@ -1,3 +1,5 @@
+utils/torch_jit_cache
+=*
 .DS_Store
 .idea/
 .setup_completed
speech_recognition/whisper/run_hf.py

Lines changed: 50 additions & 0 deletions
@@ -0,0 +1,50 @@
+# SPDX-License-Identifier: Apache-2.0
+# Copyright (c) 2024, Ampere Computing LLC
+
+TORCH_JIT_TRACE = False  # otherwise, run torch.compile()
+
+
+def run_pytorch_fp32(model_name, batch_size, num_runs, timeout, **kwargs):
+    from utils.benchmark import run_model
+    from utils.misc import print_warning_message
+    from utils.pytorch import PyTorchRunnerV2, apply_compile, apply_jit_trace_module
+    from utils.speech_recognition.libri_speech_v2 import LibriSpeech
+    from transformers import WhisperProcessor, WhisperForConditionalGeneration
+    processor = WhisperProcessor.from_pretrained(model_name)
+    model = WhisperForConditionalGeneration.from_pretrained(model_name, torchscript=TORCH_JIT_TRACE)
+    model.eval()
+    librispeech = LibriSpeech()
+    if TORCH_JIT_TRACE:
+        waveform = [librispeech.get_input_array() for _ in range(batch_size)]
+        input_features = processor(
+            waveform, sampling_rate=LibriSpeech.sampling_rate, return_tensors="pt").input_features
+        model = apply_jit_trace_module(model, {"generate": input_features})
+        librispeech = LibriSpeech()  # reset
+        model = model.generate
+    else:
+        model = apply_compile(model.generate)
+
+    def single_pass_pytorch(_runner, _librispeech):
+        waveform = [_librispeech.get_input_array() for _ in range(batch_size)]
+        input_features = processor(
+            waveform, sampling_rate=LibriSpeech.sampling_rate, return_tensors="pt").input_features
+        predicted_ids = _runner.run(sum([x.shape[0] for x in waveform]), input_features)
+        decoded_output = processor.batch_decode(predicted_ids, skip_special_tokens=True)
+        for i in range(batch_size):
+            _librispeech.submit_transcription(decoded_output[i].lstrip().replace(",", "").replace(".", "").upper())
+
+    runner = PyTorchRunnerV2(model, throughput_only=True)
+    print_warning_message("Sampling rate Whisper operates at is 16,000 Hz, therefore throughput values below can be "
+                          "divided by 16,000 to derive 'seconds of processed audio per second'")
+    return run_model(single_pass_pytorch, runner, librispeech, batch_size, num_runs, timeout)
+
+
+if __name__ == "__main__":
+    from utils.helpers import DefaultArgParser
+    whisper_variants = ["openai/whisper-tiny", "openai/whisper-base", "openai/whisper-small", "openai/whisper-medium",
+                        "openai/whisper-large", "openai/whisper-large-v2", "openai/whisper-large-v3"]
+    whisper_variants = whisper_variants + [f"{name}.en" for name in whisper_variants[:4]]
+    parser = DefaultArgParser(["pytorch"])
+    parser.require_model_name(whisper_variants)
+    parser.ask_for_batch_size(1)
+    run_pytorch_fp32(**vars(parser.parse()))
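The new runner can also be driven programmatically, mirroring how the test below calls it. A minimal sketch — the argument values are simply the ones the CI test uses, not requirements:

    # Hypothetical direct invocation of the new HF Whisper runner.
    from speech_recognition.whisper.run_hf import run_pytorch_fp32

    results = run_pytorch_fp32(
        model_name="openai/whisper-tiny.en",  # any variant accepted by the parser
        batch_size=4,
        num_runs=18,
        timeout=None,
    )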

tests/test_pytorch_models.py

Lines changed: 16 additions & 5 deletions
@@ -93,24 +93,35 @@ def wrapper(**kwargs):

 class Whisper(unittest.TestCase):
     def setUp(self):
-        from speech_recognition.whisper.run import run_pytorch_fp32
+        def wrapper_openai(**kwargs):
+            from speech_recognition.whisper.run import run_pytorch_fp32
+            kwargs["q"].put(run_pytorch_fp32(**kwargs)[0])

-        def wrapper(**kwargs):
+        def wrapper_hf(**kwargs):
+            from speech_recognition.whisper.run_hf import run_pytorch_fp32
             kwargs["q"].put(run_pytorch_fp32(**kwargs)[0])

-        self.wrapper = wrapper
+        self.wrapper_openai = wrapper_openai
+        self.wrapper_hf = wrapper_hf

     @unittest.skipIf(psutil.virtual_memory().available / 1024 ** 3 < 50, "too little memory")
     def test_whisper_tiny_en(self):
         wer_ref = 0.155
-        acc = run_process(self.wrapper, {"model_name": "tiny.en", "num_runs": 30, "timeout": None})
+        acc = run_process(self.wrapper_openai, {"model_name": "tiny.en", "num_runs": 30, "timeout": None})
+        self.assertTrue(wer_ref / acc["wer_score"] > 0.95)
+
+    @unittest.skipIf(psutil.virtual_memory().available / 1024 ** 3 < 50, "too little memory")
+    def test_whisper_hf_tiny_en(self):
+        wer_ref = 0.111
+        acc = run_process(self.wrapper_hf, {"model_name": "openai/whisper-tiny.en", "num_runs": 18,
+                                            "batch_size": 4, "timeout": None})
         self.assertTrue(wer_ref / acc["wer_score"] > 0.95)

     @unittest.skipIf(psutil.virtual_memory().available / 1024 ** 3 < 100, "too little memory")
     @unittest.skipUnless('_aio_profiler_print' in dir(torch._C), "too slow to run with native")
     def test_whisper_large(self):
         wer_ref = 0.124
-        acc = run_process(self.wrapper, {"model_name": "large", "num_runs": 30, "timeout": None})
+        acc = run_process(self.wrapper_openai, {"model_name": "large", "num_runs": 30, "timeout": None})
         self.assertTrue(wer_ref / acc["wer_score"] > 0.95)

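The accuracy gate wer_ref / acc["wer_score"] > 0.95 tolerates roughly a 5% relative regression: the measured WER must stay below wer_ref / 0.95. A worked check for the new HF test:

    wer_ref = 0.111               # reference WER for openai/whisper-tiny.en above
    max_allowed = wer_ref / 0.95  # = 0.1168...; a measured WER above this fails
    print(f"test passes while wer_score < {max_allowed:.4f}")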
utils/benchmark.py

Lines changed: 1 addition & 1 deletion
@@ -14,7 +14,7 @@

 warnings.filterwarnings("ignore")

-WARM_UP_RUNS = 3
+WARM_UP_RUNS = 9
 intra_op_parallelism_threads = None
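Tripling the warm-up count plausibly pairs with the torch.compile() path introduced in this commit: the first calls into a compiled generate() trigger compilation (and, with varying sequence lengths, recompilation), and timing them would skew throughput. A sketch of the idea, with hypothetical names (model, example_input), not the repo's actual harness:

    import time

    def benchmark(model, example_input, num_runs, warm_up_runs=9):
        for _ in range(warm_up_runs):
            model(example_input)   # compilation/autotuning happens here, untimed
        start = time.time()
        for _ in range(num_runs):
            model(example_input)   # only steady-state calls are timed
        return num_runs / (time.time() - start)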

utils/pytorch.py

Lines changed: 7 additions & 5 deletions
@@ -248,7 +248,12 @@ def apply_jit_trace(model, example_inputs):
     return load_from_cache_or_apply(model, lambda: torch.jit.trace(model, example_inputs))


+def apply_jit_trace_module(model, example_inputs):
+    return load_from_cache_or_apply(model, lambda: torch.jit.trace_module(model, example_inputs))
+
+
 def apply_compile(model):
+    torch._dynamo.config.cache_size_limit = 512
     if os.environ.get("TORCH_COMPILE") == "0":
         return model
     if version.parse(pkg_resources.get_distribution("torch").version) >= version.parse("1.14"):
@@ -264,11 +269,8 @@ def apply_compile(model):
264269
options = {}
265270
utils.print_warning_message(
266271
f"AIO unavailable or disabled, applying torch.compile() with \"{backend}\" backend.")
267-
return torch.compile(
268-
model,
269-
backend=backend,
270-
options=options
271-
)
272+
model = torch.compile(model, backend=backend, options=options)
273+
return model
272274
else:
273275
utils.print_goodbye_message_and_die(
274276
f"Installed PyTorch version is {pkg_resources.get_distribution('torch').version}. "

utils/speech_recognition/libri_speech_v2.py

Lines changed: 5 additions & 4 deletions
@@ -18,7 +18,9 @@ def __init__(self):

     def get_input_array(self):
         try:
-            return self._librispeech["audio"][self._idx]["array"]
+            array = self._librispeech["audio"][self._idx]["array"]
+            self._idx += 1
+            return array
         except IndexError:
             if os.environ.get("IGNORE_DATASET_LIMITS") == "1":
                 if self.reset():
@@ -28,9 +30,7 @@ def get_input_array(self):
     def submit_transcription(self, text: str):
         if self.do_skip():
             return
-
         self._transcriptions.append(text)
-        self._idx += 1

     def reset(self):
         self._idx = 0
@@ -41,10 +41,11 @@ def _summarize_accuracy(self):
         if self.do_skip():
             return

-        assert len(self._transcriptions) == len(self._librispeech["text"][:self._idx])
+        assert len(self._transcriptions) == self._idx
         wer_score = load("wer").compute(
             references=self._librispeech["text"][:self._idx], predictions=self._transcriptions
         )
+        assert wer_score <= 1.0
         # print("\n WER score = {:.3f}".format(wer_score))
         # print(f"\n Accuracy figures above calculated on the basis of {self._idx} sample(s).")
         return {"wer_score": wer_score}
