This repository was archived by the owner on Oct 25, 2024. It is now read-only.

Commit 50676da

Add shanghainese ASR TTS finetuning and inference example (#1588)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

1 parent 2f7feb7 commit 50676da

File tree: 24 files changed, +2900 −0 lines changed

Lines changed: 124 additions & 0 deletions
@@ -0,0 +1,124 @@
# Shanghainese ASR (Automatic Speech Recognition) and TTS (Text-To-Speech) finetuning/inference

This example introduces how to do Shanghainese audio-to-text and text-to-audio conversion.

## Related models

### ASR

* Conversion from Shanghainese audio to Shanghainese text

  Finetuned [jonatasgrosman/wav2vec2-large-xlsr-53-chinese-zh-cn](https://huggingface.co/jonatasgrosman/wav2vec2-large-xlsr-53-chinese-zh-cn)

* Conversion from Shanghainese text to Mandarin text

  Finetuned [Helsinki-NLP/opus-mt-zh-en](https://huggingface.co/Helsinki-NLP/opus-mt-zh-en)

### TTS

* Conversion from Mandarin text to Shanghainese text

  Finetuned [Helsinki-NLP/opus-mt-en-zh](https://huggingface.co/Helsinki-NLP/opus-mt-en-zh)

* Conversion from Shanghainese text to Shanghainese audio

  Finetuned [VITS](https://github.com/jaywalnut310/vits)
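
The `app.py` demo later in this commit chains these models together. As a minimal sketch of the ASR path (assuming the `spycsh/*` checkpoint IDs that `app.py` loads are available on the Hugging Face Hub, and a local `test.wav`):

```py
# Minimal ASR-path sketch, not the full demo; model IDs are taken from app.py below.
import torch
from huggingsound import SpeechRecognitionModel
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

device = "cuda" if torch.cuda.is_available() else "cpu"

# Stage 1: Shanghainese audio -> Shanghainese text (finetuned wav2vec2)
asr = SpeechRecognitionModel("spycsh/shanghainese-wav2vec-3800", device=device)
sh_text = asr.transcribe(["test.wav"])[0]["transcription"]

# Stage 2: Shanghainese text -> Mandarin text (finetuned Marian MT model)
tokenizer = AutoTokenizer.from_pretrained("spycsh/shanghainese-opus-sh-zh-3500")
model = AutoModelForSeq2SeqLM.from_pretrained("spycsh/shanghainese-opus-sh-zh-3500").to(device)
input_ids = tokenizer(sh_text, return_tensors="pt", max_length=64, truncation=True).input_ids.to(device)
zh_text = tokenizer.decode(model.generate(input_ids, max_new_tokens=64)[0], skip_special_tokens=True)
print(sh_text, "->", zh_text)
```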

## Prepare Environment

### 1. Install requirements

```sh
pip install -r requirements.txt
# force to install torch 2.0+
pip install torch==2.1.0
```

<!-- ### 2. Build monotonic alignment search

```sh
cd monotonic_align
mkdir monotonic_align
python setup.py build_ext --inplace
cd ..
``` -->
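
Optionally, a quick sanity check that the forced torch version took effect; the device line mirrors the selection `app.py` makes below:

```py
import torch

print(torch.__version__)                               # expect 2.1.0 after the forced install
print("cuda" if torch.cuda.is_available() else "cpu")  # device app.py will run on
```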

## Finetuning

### ASR

#### 1. Prepare the data

Please check this [repo](https://github.com/Cosmos-Break/asr) and download the folders named like `Shanghai_*` to the current directory.

#### 2. Finetune the Shanghainese audio -> Shanghainese text ASR model

```sh
python train_asr.py
```

#### 3. Finetune the Shanghainese text -> Mandarin text translation model

```sh
python train_translation.py
```

### TTS

#### 1. Download the pre-finetuned VITS model

Download the [model](https://sjtueducn-my.sharepoint.com/:u:/g/personal/cjang_cjengh_sjtu_edu_cn/EfnEO6kW-CNNhywJmIZNPU0BUmFdSArguFETp0pjtvHZBA?e=dKJULk) (2796 epochs)

Put the `model.pth` model under `model/`.

#### 2. Finetune the Mandarin text -> Shanghainese text translation model

```sh
python train_translation_reverse.py
```

## Inference

### ASR

#### Run inference with the Shanghainese audio -> Shanghainese text ASR model

```sh
python inference_asr.py
```

#### Run inference with the Shanghainese text -> Mandarin text translation model

```sh
python inference_translation.py
```

### TTS

#### Run inference with the Mandarin text -> Shanghainese text translation model

```sh
python inference_translation_reverse.py
```

#### Run inference with the Shanghainese text -> Shanghainese audio TTS model

```sh
python inference_tts.py
```
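
These two scripts mirror the TTS path of `app.py` shown later in this commit. A condensed sketch of that path, assuming the same repo-local modules (`utils`, `commons`, `models`, `text`) and the `model/config.json` and `model/model.pth` files set up above:

```py
# Condensed TTS-path sketch, based on the t2s() function in app.py below.
import torch
import commons
import utils
import soundfile as sf
from models import SynthesizerTrn
from text import text_to_sequence

device = "cuda" if torch.cuda.is_available() else "cpu"
hps = utils.get_hparams_from_file("model/config.json")
net_g = SynthesizerTrn(
    len(hps.symbols),
    hps.data.filter_length // 2 + 1,
    hps.train.segment_size // hps.data.hop_length,
    n_speakers=hps.data.n_speakers,
    **hps.model).eval().to(device)
utils.load_checkpoint("model/model.pth", net_g)

text = "请侬让只位子,拨需要帮助个乘客,谢谢侬。"  # Shanghainese input text
seq = text_to_sequence(text, hps.symbols, hps.data.text_cleaners)
if hps.data.add_blank:
    seq = commons.intersperse(seq, 0)
x = torch.LongTensor(seq).unsqueeze(0).to(device)
x_lengths = torch.LongTensor([x.size(1)]).to(device)
with torch.no_grad():
    audio = net_g.infer(x, x_lengths, sid=torch.LongTensor([0]).to(device),
                        noise_scale=.667, noise_scale_w=0.8, length_scale=1)[0][0, 0].cpu().float().numpy()
sf.write("tts_out.wav", audio, hps.data.sampling_rate)
```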

## Demo

```sh
export no_proxy="localhost,127.0.0.1"
nohup python -u app.py &
```

![asr-sh](https://imgur.com/dAB4vxj.png)

![tts-sh](https://imgur.com/0i0xcVH.png)

## Acknowledgements

The code is adapted from [Cosmos-Break/asr](https://github.com/Cosmos-Break/asr) and [CjangCjengh/vits](https://github.com/CjangCjengh/vits). We thank the authors for their great work!
Lines changed: 185 additions & 0 deletions
@@ -0,0 +1,185 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
#
# Copyright (c) 2024 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch

import gradio as gr
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import tempfile
import os
from huggingsound import SpeechRecognitionModel
import commons

import soundfile as sf
import utils
from models import SynthesizerTrn
from text import text_to_sequence

"""Usage:
export no_proxy="localhost,127.0.0.1"
nohup python -u app.py &
"""

# Finetuned checkpoints: ASR, Shanghainese -> Mandarin translation,
# and Mandarin -> Shanghainese (reverse) translation
ASR_MODEL_PATH = "spycsh/shanghainese-wav2vec-3800"
TRANSLATE_MODEL_PATH = "spycsh/shanghainese-opus-sh-zh-3500"

REVERSE_MODEL_NAME = "spycsh/shanghainese-opus-zh-sh-4000"

device = "cuda" if torch.cuda.is_available() else "cpu"
batch_size = 1

# wav2vec2-based ASR model: Shanghainese audio -> Shanghainese text
asr_model = SpeechRecognitionModel(ASR_MODEL_PATH, device=device)

# Marian MT model: Shanghainese text -> Mandarin text
translate_tokenizer = AutoTokenizer.from_pretrained(TRANSLATE_MODEL_PATH)
translate_model = AutoModelForSeq2SeqLM.from_pretrained(TRANSLATE_MODEL_PATH).to(device)

# Marian MT model: Mandarin text -> Shanghainese text
reverse_translate_tokenizer = AutoTokenizer.from_pretrained(REVERSE_MODEL_NAME)
reverse_translate_model = AutoModelForSeq2SeqLM.from_pretrained(REVERSE_MODEL_NAME).to(device)

# VITS synthesizer: Shanghainese text -> Shanghainese audio
hps = utils.get_hparams_from_file("model/config.json")
n_speakers = hps.data.n_speakers
net_g = SynthesizerTrn(
    len(hps.symbols),
    hps.data.filter_length // 2 + 1,
    hps.train.segment_size // hps.data.hop_length,
    n_speakers=n_speakers,
    **hps.model)
_ = net_g.eval()
net_g = net_g.to(device)

_ = utils.load_checkpoint("model/model.pth", net_g)


demo = gr.Blocks()

def generate_translation(model, tokenizer, example):
    """Generate a translation for the tokenized example and return the predicted raw text."""

    input_ids = example['input_ids']
    input_ids = torch.LongTensor(input_ids).view(1, -1).to(model.device)
    # print('input_ids: ', input_ids)
    generated_ids = model.generate(input_ids, max_new_tokens=64)
    # print('generated_ids: ', generated_ids)
    prediction = tokenizer.decode(generated_ids[0], skip_special_tokens=True)

    print('prediction: ', prediction)
    return prediction

def transcribe(inputs, translate=False):
    if inputs is None:
        raise gr.Error("No audio file submitted! Please upload or record an audio file before submitting your request.")
    print(inputs)
    # inputs is a (sample_rate, waveform) tuple from gr.Audio,
    # e.g. (44100, array([ 0, 0, 0, ..., -60, 50, -37], dtype=int16))
    sr, waveform = inputs
    sf.write("test.wav", waveform, sr, format="wav")
    if not translate:
        # Shanghainese audio -> Shanghainese text
        return asr_model.transcribe(["test.wav"])[0]['transcription']
    else:
        # Shanghainese audio -> Shanghainese text -> Mandarin text
        txt = asr_model.transcribe(["test.wav"])[0]['transcription']
        with translate_tokenizer.as_target_tokenizer():
            model_inputs = translate_tokenizer(txt, max_length=64, truncation=True)
        example = {}
        example['sh'] = txt
        example['zh'] = txt
        example['input_ids'] = model_inputs['input_ids']
        print(txt)
        print(example)
        return generate_translation(translate_model, translate_tokenizer, example)

translate = gr.Checkbox(label='Translate into Mandarin')

asr_tab = gr.Interface(
    fn=transcribe,
    inputs=[
        gr.Audio(sources=["microphone", "upload"],
            waveform_options=gr.WaveformOptions(
                waveform_color="#01C6FF",
                waveform_progress_color="#0066B4",
                skip_length=2,
                show_controls=False,
            )
        ),
        translate
    ],
    outputs="text",
    title="Shanghainese ASR",
    description=(
        "Transcribe Shanghainese long-form microphone or audio inputs into text, with optional translation into Mandarin, with the click of a button!"
    ),
    allow_flagging="never",
)

def get_text(text, hps):
    # Convert raw text into a sequence of VITS symbol IDs
    text_norm = text_to_sequence(text, hps.symbols, hps.data.text_cleaners)
    if hps.data.add_blank:
        text_norm = commons.intersperse(text_norm, 0)
    text_norm = torch.LongTensor(text_norm)
    return text_norm

def t2s(inputs, reverse_translate=False):
    if inputs is None:
        raise gr.Error("No input text found! Please check the input text!")
    print(inputs)  # inputs: text
    text = inputs
    if reverse_translate:
        # Mandarin text -> Shanghainese text before synthesis
        model_inputs = reverse_translate_tokenizer(inputs, max_length=64, truncation=True)
        example = {}
        example['sh'] = text
        example['zh'] = text
        example['input_ids'] = model_inputs['input_ids']
        text = generate_translation(reverse_translate_model, reverse_translate_tokenizer, example)
        print(text)

    # Shanghainese text -> Shanghainese audio with VITS
    stn_tst = get_text(text, hps)
    with torch.no_grad():
        x_tst = stn_tst.unsqueeze(0).to(device)
        x_tst_lengths = torch.LongTensor([stn_tst.size(0)]).to(device)
        sid = torch.LongTensor([0]).to(device)
        print(x_tst, x_tst_lengths, sid)
        audio = net_g.infer(x_tst, x_tst_lengths, sid=sid, noise_scale=.667, noise_scale_w=0.8, length_scale=1)[0][0, 0].data.cpu().float().numpy()
    print(audio)
    return (hps.data.sampling_rate, audio)

reverse_translate = gr.Checkbox(value=False, label='Mandarin as input text')

tts_tab = gr.Interface(
    fn=t2s,
    inputs=[
        # Default Shanghainese prompt, roughly: "Please offer this seat to passengers in need, thank you."
        gr.Textbox(label="input text", value="请侬让只位子,拨需要帮助个乘客,谢谢侬。"),
        reverse_translate
    ],
    outputs="audio",
    title="Shanghainese TTS",
    description=(
        "Shanghainese Text To Speech with one click!"
    ),
    allow_flagging="never",
)

with demo:
    gr.TabbedInterface([asr_tab, tts_tab], ["SH-ASR", "SH-TTS"])

demo.launch(share=True)
