
Commit ae8c828

add whisper translate (#235)
1 parent e1581d0 commit ae8c828

6 files changed, 172 insertions(+), 0 deletions(-)


.github/workflows/test.yml

Lines changed: 3 additions & 0 deletions
```diff
@@ -119,6 +119,7 @@ jobs:
       S3_URL_IMAGENET_DATASET_LABELS: ${{ secrets.S3_URL_IMAGENET_DATASET_LABELS }}
       S3_URL_COCO_DATASET: ${{ secrets.S3_URL_COCO_DATASET }}
       S3_URL_COCO_DATASET_ANNOTATIONS: ${{ secrets.S3_URL_COCO_DATASET_ANNOTATIONS }}
+      S3_URL_COVOST2_DATASET: ${{ secrets.S3_URL_COVOST2_DATASET }}
       HF_HUB_TOKEN: ${{ secrets.HF_HUB_TOKEN }}
     steps:
       - name: Install git
@@ -179,6 +180,7 @@ jobs:
       S3_URL_IMAGENET_DATASET_LABELS: ${{ secrets.S3_URL_IMAGENET_DATASET_LABELS }}
       S3_URL_COCO_DATASET: ${{ secrets.S3_URL_COCO_DATASET }}
       S3_URL_COCO_DATASET_ANNOTATIONS: ${{ secrets.S3_URL_COCO_DATASET_ANNOTATIONS }}
+      S3_URL_COVOST2_DATASET: ${{ secrets.S3_URL_COVOST2_DATASET }}
       HF_HUB_TOKEN: ${{ secrets.HF_HUB_TOKEN }}
     steps:
       - name: Install Ampere optimized PyTorch
@@ -221,6 +223,7 @@ jobs:
       S3_URL_IMAGENET_DATASET_LABELS: ${{ secrets.S3_URL_IMAGENET_DATASET_LABELS }}
       S3_URL_COCO_DATASET: ${{ secrets.S3_URL_COCO_DATASET }}
       S3_URL_COCO_DATASET_ANNOTATIONS: ${{ secrets.S3_URL_COCO_DATASET_ANNOTATIONS }}
+      S3_URL_COVOST2_DATASET: ${{ secrets.S3_URL_COVOST2_DATASET }}
       HF_HUB_TOKEN: ${{ secrets.HF_HUB_TOKEN }}
     steps:
       - name: Git checkout & pull submodules
```

setup_deb.sh

Lines changed: 6 additions & 0 deletions
```diff
@@ -174,6 +174,12 @@ pip3 install --no-deps --upgrade \
     streamlit-drawable-canvas==0.8.0 \
     safetensors>=0.3.1
 
+apt install -y autoconf autogen automake build-essential libasound2-dev \
+    libflac-dev libogg-dev libtool libvorbis-dev libopus-dev libmp3lame-dev \
+    libmpg123-dev pkg-config
+apt remove -y libsndfile1
+git clone https://github.com/libsndfile/libsndfile.git && cd libsndfile/ && autoreconf -vif && ./configure --enable-werror && make -j && make install && ldconfig && cd .. && rm -rf libsndfile
+
 if [ "$(PYTHONPATH=$SCRIPT_DIR python3 -c 'from cpuinfo import get_cpu_info; from benchmark import which_ampere_cpu; cpu = which_ampere_cpu(get_cpu_info()["flags"], 1); print("AmpereOne" in cpu)')" == "True" ]; then
     # Only on AmpereOne family
     pip3 install --upgrade --no-deps \
```
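The setup change above rebuilds libsndfile from source with the ALSA, FLAC, Ogg/Vorbis, Opus, LAME and mpg123 development packages, presumably so that the library can decode the MP3 clips shipped with Common Voice, which the distro's libsndfile1 package may not support. A minimal sanity check from Python, assuming the python-soundfile package is installed (it is not part of this commit):

```python
# hypothetical check, assuming `pip3 install soundfile` has been run
import soundfile as sf

print(sf.__libsndfile_version__)  # expect 1.1.0 or newer after the source build
print(sf.available_formats())     # the listed formats should now include an MPEG/MP3 entry
```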
speech_recognition/whisper_translate/README.md

Lines changed: 43 additions & 0 deletions
# Whisper translate

This folder contains the script to run Whisper on the speech-to-text translation task in the PyTorch framework.

The original paper on the architecture is available here: https://arxiv.org/pdf/2212.04356.pdf

### Dataset

Download the Common Voice Corpus for the Japanese language here: https://commonvoice.mozilla.org/en/datasets

Extract the dataset:
```
tar -xvf ja.tar
```

### Running instructions

Before running any code, you should first export the PYTHONPATH variable with the path pointing to the Ampere Model Library directory, as well as AIO_NUM_THREADS specifying the number of threads to be used.

```
export PYTHONPATH=/path/to/ampere_model_library
export AIO_NUM_THREADS=1
```

For the best experience, we also recommend setting the environment variable specified below.

```
export COMMONVOICE_PATH=/path/to/dataset
```

Now you are able to run the run.py script.

To get detailed information on the script's recognized arguments, run it with the -h flag.

For the PyTorch implementation, the size of the model has to be specified with the "-m" flag.

Example command for PyTorch:

```
python3 run.py -m medium --timeout 600
```
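Outside of run.py, the same speech-to-text translation task can be exercised directly through the openai-whisper API. A minimal sketch, assuming the whisper package is installed and `clip.mp3` is a hypothetical Japanese Common Voice recording:

```python
import whisper  # assumes the openai-whisper package is available

model = whisper.load_model("medium")
# task="translate" makes Whisper emit English text for the Japanese audio
result = model.transcribe("clip.mp3", task="translate", language="ja", verbose=None)
print(result["text"])
```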
speech_recognition/whisper_translate/run.py

Lines changed: 44 additions & 0 deletions
```python
# SPDX-License-Identifier: Apache-2.0
# Copyright (c) 2024, Ampere Computing LLC
import os
import sys
import torch


def run_pytorch_fp32(model_name, num_runs, timeout, dataset_path, **kwargs):
    batch_size = 1
    sys.path.append(os.path.join(os.path.dirname(os.path.realpath(__file__)), "whisper"))
    from utils.benchmark import run_model
    from utils.misc import print_warning_message
    from utils.pytorch import PyTorchRunnerV2
    from utils.speech_recognition.covost2 import Covost2
    from speech_recognition.whisper.whisper.whisper import load_model
    from speech_recognition.whisper.whisper.whisper.transcribe import transcribe
    model = load_model(model_name)
    model.eval()

    def single_pass_pytorch(_runner, _covost2):
        # translate a single CoVoST2 clip and store the English text
        # with leading whitespace and periods removed
        array = _covost2.get_input_array()
        audio = torch.from_numpy(array.astype("float32"))
        _covost2.submit_translation(
            _runner.run(batch_size * array.shape[0], audio)["text"].lstrip().replace(".", "")
        )

    def translate_wrapper(audio):
        return transcribe(model, audio, verbose=None, task="translate", language="ja")

    runner = PyTorchRunnerV2(translate_wrapper, throughput_only=True)
    covost2 = Covost2(dataset_path)
    print_warning_message("The sampling rate Whisper operates at is 16,000 Hz, therefore throughput values below "
                          "can be divided by 16,000 to derive 'seconds of processed audio per second'")
    return run_model(single_pass_pytorch, runner, covost2, batch_size, num_runs, timeout)


if __name__ == "__main__":
    from utils.helpers import DefaultArgParser
    whisper_variants = ["tiny", "base", "small", "medium", "large"]
    parser = DefaultArgParser(["pytorch"])
    parser.require_model_name(whisper_variants)
    parser.add_argument("--dataset_path", type=str, required=True,
                        help="path to the CommonVoice Japanese dataset directory")
    run_pytorch_fp32(**vars(parser.parse()))
```
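The warning printed by run_pytorch_fp32 implies a simple conversion from the reported throughput (audio samples per second) to a real-time factor. A small illustration with a made-up throughput value:

```python
SAMPLING_RATE = 16_000          # Hz, the rate Whisper operates at

reported_throughput = 48_000.0  # hypothetical samples-per-second figure from the benchmark
real_time_factor = reported_throughput / SAMPLING_RATE
print(real_time_factor)         # 3.0 seconds of processed audio per second
```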

tests/test_pytorch_models.py

Lines changed: 31 additions & 0 deletions
```diff
@@ -114,6 +114,37 @@ def test_whisper_large(self):
         self.assertTrue(wer_ref / acc["wer_score"] > 0.95)
 
 
+class WhisperTranslate(unittest.TestCase):
+    def setUp(self):
+        from speech_recognition.whisper_translate.run import run_pytorch_fp32
+
+        self.dataset_path = pathlib.Path(get_downloads_path(), "covost2_ja")
+        if not self.dataset_path.exists():
+            url = os.environ.get("S3_URL_COVOST2_DATASET")
+            assert url is not None
+            subprocess.run(f"mkdir {self.dataset_path}".split(),
+                           check=True, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL)
+            subprocess.run(f"wget -P /tmp {url}".split(),
+                           check=True, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL)
+            subprocess.run(f"tar -xf /tmp/covost2_ja.tar -C {self.dataset_path}".split(),
+                           check=True, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL)
+            subprocess.run("rm /tmp/covost2_ja.tar".split(),
+                           check=True, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL)
+
+        def wrapper(**kwargs):
+            kwargs["q"].put(run_pytorch_fp32(**kwargs)[0])
+
+        self.wrapper = wrapper
+
+    @unittest.skipIf(psutil.virtual_memory().available / 1024 ** 3 < 100, "too little memory")
+    @unittest.skipUnless('_aio_profiler_print' in dir(torch._C), "too slow to run with native")
+    def test_whisper_translate_medium(self):
+        bleu_ref = 0.475
+        acc = run_process(self.wrapper, {"model_name": "medium", "num_runs": 30, "timeout": None,
+                                         "dataset_path": self.dataset_path})
+        self.assertTrue(acc["bleu_score"] / bleu_ref > 0.95)
+
+
 class DLRM(unittest.TestCase):
     def setUp(self):
         self.dataset_path = pathlib.Path(get_downloads_path(), "criteo_preprocessed")
```
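The wrapper defined in setUp follows a queue-based pattern: the benchmark runs in a child process (via the repo's run_process helper) and hands its first return value back through a multiprocessing queue. A self-contained sketch of that pattern, with a stand-in function instead of run_pytorch_fp32:

```python
# illustrative only; the repo's run_process helper is assumed to work along these lines
import multiprocessing as mp


def wrapper(**kwargs):
    # stand-in for run_pytorch_fp32(**kwargs)[0]
    kwargs["q"].put(sum(kwargs["values"]))


if __name__ == "__main__":
    q = mp.Queue()
    p = mp.Process(target=wrapper, kwargs={"q": q, "values": [1, 2, 3]})
    p.start()
    p.join()
    print(q.get())  # 6
```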
utils/speech_recognition/covost2.py

Lines changed: 45 additions & 0 deletions
```python
# SPDX-License-Identifier: Apache-2.0
# Copyright (c) 2024, Ampere Computing LLC
from evaluate import load
from datasets import load_dataset
import utils.misc as utils
from utils.misc import OutOfInstances
from utils.helpers import Dataset


class Covost2(Dataset):
    """CoVoST 2 Japanese -> English speech translation dataset (validation split)."""
    sampling_rate = 16000

    def __init__(self, dataset_path=None):
        if dataset_path is None:
            env_var = "COMMONVOICE_PATH"
            dataset_path = utils.get_env_variable(
                env_var, f"Path to the CommonVoice directory has not been specified with the {env_var} variable")

        self._covost2 = load_dataset("covost2", "ja_en", split="validation", data_dir=dataset_path)
        self.available_instances = len(self._covost2["audio"])
        self._idx = 0
        self._translations = []

    def get_input_array(self):
        try:
            return self._covost2["audio"][self._idx]["array"]
        except IndexError:
            raise OutOfInstances

    def submit_translation(self, text: str):
        self._translations.append(text)
        self._idx += 1

    def reset(self):
        self._idx = 0
        self._translations = []
        return True

    def _summarize_accuracy(self):
        # compare the submitted translations against the reference English translations seen so far
        assert len(self._translations) == len(self._covost2["translation"][:self._idx])
        bleu_score = load("bleu").compute(
            references=self._covost2["translation"][:self._idx], predictions=self._translations
        )
        return {"bleu_score": bleu_score["bleu"]}
```
