Skip to content

Commit afea561

Browse files
committed
[TTS/Torchhub] Expose HiFiGAN and FastPitch via TorchHub
1 parent 6a16011 commit afea561

File tree

6 files changed

+324
-2
lines changed

6 files changed

+324
-2
lines changed

PyTorch/SpeechSynthesis/HiFiGAN/common/utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -51,8 +51,6 @@
5151

5252
import matplotlib
5353

54-
matplotlib.use("Agg")
55-
import matplotlib.pylab as plt
5654
import numpy as np
5755
import torch
5856
import torch.distributed as dist
@@ -173,6 +171,8 @@ def print_once(*msg):
173171

174172

175173
def plot_spectrogram(spectrogram):
174+
matplotlib.use("Agg")
175+
import matplotlib.pylab as plt
176176
fig, ax = plt.subplots(figsize=(10, 2))
177177
im = ax.imshow(spectrogram, aspect="auto", origin="lower",
178178
interpolation='none')
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
from .entrypoints import nvidia_fastpitch, nvidia_textprocessing_utils
Lines changed: 203 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,203 @@
1+
# *****************************************************************************
2+
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
3+
#
4+
# Redistribution and use in source and binary forms, with or without
5+
# modification, are permitted provided that the following conditions are met:
6+
# * Redistributions of source code must retain the above copyright
7+
# notice, this list of conditions and the following disclaimer.
8+
# * Redistributions in binary form must reproduce the above copyright
9+
# notice, this list of conditions and the following disclaimer in the
10+
# documentation and/or other materials provided with the distribution.
11+
# * Neither the name of the NVIDIA CORPORATION nor the
12+
# names of its contributors may be used to endorse or promote products
13+
# derived from this software without specific prior written permission.
14+
#
15+
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
16+
# ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
17+
# WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
18+
# DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
19+
# DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
20+
# (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
21+
# LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
22+
# ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
23+
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
24+
# SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
25+
#
26+
# *****************************************************************************
27+
28+
import urllib.request
29+
import torch
30+
import os
31+
import sys
32+
33+
#from https://github.com/NVIDIA/DeepLearningExamples/blob/master/PyTorch/SpeechSynthesis/Tacotron2/inference.py
34+
def checkpoint_from_distributed(state_dict):
35+
"""
36+
Checks whether checkpoint was generated by DistributedDataParallel. DDP
37+
wraps model in additional "module.", it needs to be unwrapped for single
38+
GPU inference.
39+
:param state_dict: model's state dict
40+
"""
41+
ret = False
42+
for key, _ in state_dict.items():
43+
if key.find('module.') != -1:
44+
ret = True
45+
break
46+
return ret
47+
48+
49+
# from https://github.com/NVIDIA/DeepLearningExamples/blob/master/PyTorch/SpeechSynthesis/Tacotron2/inference.py
50+
def unwrap_distributed(state_dict):
51+
"""
52+
Unwraps model from DistributedDataParallel.
53+
DDP wraps model in additional "module.", it needs to be removed for single
54+
GPU inference.
55+
:param state_dict: model's state dict
56+
"""
57+
new_state_dict = {}
58+
for key, value in state_dict.items():
59+
new_key = key.replace('module.1.', '')
60+
new_key = new_key.replace('module.', '')
61+
new_state_dict[new_key] = value
62+
return new_state_dict
63+
64+
def _download_checkpoint(checkpoint, force_reload):
65+
model_dir = os.path.join(torch.hub._get_torch_home(), 'checkpoints')
66+
if not os.path.exists(model_dir):
67+
os.makedirs(model_dir)
68+
ckpt_file = os.path.join(model_dir, os.path.basename(checkpoint))
69+
if not os.path.exists(ckpt_file) or force_reload:
70+
sys.stderr.write('Downloading checkpoint from {}\n'.format(checkpoint))
71+
urllib.request.urlretrieve(checkpoint, ckpt_file)
72+
return ckpt_file
73+
74+
75+
def nvidia_fastpitch(pretrained=True, **kwargs):
76+
"""TODO
77+
"""
78+
79+
from fastpitch import model as fastpitch
80+
81+
force_reload = "force_reload" in kwargs and kwargs["force_reload"]
82+
fp16 = "model_math" in kwargs and kwargs["model_math"] == "fp16"
83+
84+
if pretrained:
85+
checkpoint = 'https://api.ngc.nvidia.com/v2/models/nvidia/dle/fastpitch__pyt_ckpt/versions/21.12.1_amp/files/nvidia_fastpitch_210824+cfg.pt'
86+
ckpt_file = _download_checkpoint(checkpoint, force_reload)
87+
ckpt = torch.load(ckpt_file)
88+
state_dict = ckpt['state_dict']
89+
if checkpoint_from_distributed(state_dict):
90+
state_dict = unwrap_distributed(state_dict)
91+
config = ckpt['config']
92+
train_setup = ckpt.get('train_setup', {})
93+
else:
94+
config = {'n_mel_channels': 80, 'n_symbols': 148, 'padding_idx': 0, 'symbols_embedding_dim': 384,
95+
'in_fft_n_layers': 6, 'in_fft_n_heads': 1, 'in_fft_d_head': 64, 'in_fft_conv1d_kernel_size': 3,
96+
'in_fft_conv1d_filter_size': 1536, 'in_fft_output_size': 384, 'p_in_fft_dropout': 0.1,
97+
'p_in_fft_dropatt': 0.1, 'p_in_fft_dropemb': 0.0, 'out_fft_n_layers': 6, 'out_fft_n_heads': 1,
98+
'out_fft_d_head': 64, 'out_fft_conv1d_kernel_size': 3, 'out_fft_conv1d_filter_size': 1536,
99+
'out_fft_output_size': 384, 'p_out_fft_dropout': 0.1, 'p_out_fft_dropatt': 0.1, 'p_out_fft_dropemb': 0.0,
100+
'dur_predictor_kernel_size': 3, 'dur_predictor_filter_size': 256, 'p_dur_predictor_dropout': 0.1,
101+
'dur_predictor_n_layers': 2, 'pitch_predictor_kernel_size': 3, 'pitch_predictor_filter_size': 256,
102+
'p_pitch_predictor_dropout': 0.1, 'pitch_predictor_n_layers': 2, 'pitch_embedding_kernel_size': 3,
103+
'n_speakers': 1, 'speaker_emb_weight': 1.0, 'energy_predictor_kernel_size': 3,
104+
'energy_predictor_filter_size': 256, 'p_energy_predictor_dropout': 0.1, 'energy_predictor_n_layers': 2,
105+
'energy_conditioning': True, 'energy_embedding_kernel_size': 3}
106+
for k,v in kwargs.items():
107+
if k in config.keys():
108+
config[k] = v
109+
train_setup = {}
110+
111+
model = fastpitch.FastPitch(**config)
112+
113+
if pretrained:
114+
model.load_state_dict(state_dict)
115+
116+
if fp16:
117+
model.half()
118+
119+
model.forward = model.infer
120+
121+
return model, train_setup
122+
123+
124+
def nvidia_textprocessing_utils(cmudict_path, heteronyms_path, **kwargs):
125+
126+
from common.text.text_processing import TextProcessing
127+
import numpy as np
128+
from torch.nn.utils.rnn import pad_sequence
129+
from common.text import cmudict
130+
131+
132+
class TextPreProcessing:
133+
@staticmethod
134+
def prepare_input_sequence(texts, batch_size=1, device='cpu'):
135+
cmudict.initialize(cmudict_path, heteronyms_path)
136+
tp = TextProcessing(symbol_set='english_basic', cleaner_names=['english_cleaners_v2'], p_arpabet=1.0)
137+
fields={}
138+
139+
fields['text'] = [torch.LongTensor(tp.encode_text(text))
140+
for text in texts]
141+
order = np.argsort([-t.size(0) for t in fields['text']])
142+
143+
fields['text'] = [fields['text'][i] for i in order]
144+
fields['text_lens'] = torch.LongTensor([t.size(0) for t in fields['text']])
145+
146+
for t in fields['text']:
147+
print(tp.sequence_to_text(t.numpy()))
148+
149+
# cut into batches & pad
150+
batches = []
151+
for b in range(0, len(order), batch_size):
152+
batch = {f: values[b:b+batch_size] for f, values in fields.items()}
153+
for f in batch:
154+
if f == 'text':
155+
batch[f] = pad_sequence(batch[f], batch_first=True)
156+
157+
if type(batch[f]) is torch.Tensor:
158+
batch[f] = batch[f].to(device)
159+
batches.append(batch)
160+
161+
return batches
162+
163+
return TextPreProcessing()
164+
165+
166+
167+
# # from tacotron2.text import text_to_sequence
168+
169+
# @staticmethod
170+
# def pad_sequences(batch):
171+
# # Right zero-pad all one-hot text sequences to max input length
172+
# input_lengths, ids_sorted_decreasing = torch.sort(
173+
# torch.LongTensor([len(x) for x in batch]),
174+
# dim=0, descending=True)
175+
# max_input_len = input_lengths[0]
176+
177+
# text_padded = torch.LongTensor(len(batch), max_input_len)
178+
# text_padded.zero_()
179+
# for i in range(len(ids_sorted_decreasing)):
180+
# text = batch[ids_sorted_decreasing[i]]
181+
# text_padded[i, :text.size(0)] = text
182+
183+
# return text_padded, input_lengths
184+
185+
# @staticmethod
186+
# def prepare_input_sequence(texts, cpu_run=False):
187+
188+
# d = []
189+
# # for i,text in enumerate(texts):
190+
# # d.append(torch.IntTensor(
191+
# # Processing.text_to_sequence(text, ['english_cleaners'])[:]))
192+
193+
# text_padded, input_lengths = Processing.pad_sequences(d)
194+
# if not cpu_run:
195+
# text_padded = text_padded.cuda().long()
196+
# input_lengths = input_lengths.cuda().long()
197+
# else:
198+
# text_padded = text_padded.long()
199+
# input_lengths = input_lengths.long()
200+
201+
# return text_padded, input_lengths
202+
203+
# return Processing()
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
from .entrypoints import nvidia_hifigan
Lines changed: 112 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,112 @@
1+
# *****************************************************************************
2+
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
3+
#
4+
# Redistribution and use in source and binary forms, with or without
5+
# modification, are permitted provided that the following conditions are met:
6+
# * Redistributions of source code must retain the above copyright
7+
# notice, this list of conditions and the following disclaimer.
8+
# * Redistributions in binary form must reproduce the above copyright
9+
# notice, this list of conditions and the following disclaimer in the
10+
# documentation and/or other materials provided with the distribution.
11+
# * Neither the name of the NVIDIA CORPORATION nor the
12+
# names of its contributors may be used to endorse or promote products
13+
# derived from this software without specific prior written permission.
14+
#
15+
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
16+
# ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
17+
# WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
18+
# DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
19+
# DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
20+
# (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
21+
# LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
22+
# ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
23+
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
24+
# SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
25+
#
26+
# *****************************************************************************
27+
28+
import urllib.request
29+
import torch
30+
import os
31+
import sys
32+
33+
#from https://github.com/NVIDIA/DeepLearningExamples/blob/master/PyTorch/SpeechSynthesis/Tacotron2/inference.py
34+
def checkpoint_from_distributed(state_dict):
35+
"""
36+
Checks whether checkpoint was generated by DistributedDataParallel. DDP
37+
wraps model in additional "module.", it needs to be unwrapped for single
38+
GPU inference.
39+
:param state_dict: model's state dict
40+
"""
41+
ret = False
42+
for key, _ in state_dict.items():
43+
if key.find('module.') != -1:
44+
ret = True
45+
break
46+
return ret
47+
48+
49+
# from https://github.com/NVIDIA/DeepLearningExamples/blob/master/PyTorch/SpeechSynthesis/Tacotron2/inference.py
50+
def unwrap_distributed(state_dict):
51+
"""
52+
Unwraps model from DistributedDataParallel.
53+
DDP wraps model in additional "module.", it needs to be removed for single
54+
GPU inference.
55+
:param state_dict: model's state dict
56+
"""
57+
new_state_dict = {}
58+
for key, value in state_dict.items():
59+
new_key = key.replace('module.1.', '')
60+
new_key = new_key.replace('module.', '')
61+
new_state_dict[new_key] = value
62+
return new_state_dict
63+
64+
def _download_checkpoint(checkpoint, force_reload):
65+
model_dir = os.path.join(torch.hub._get_torch_home(), 'checkpoints')
66+
if not os.path.exists(model_dir):
67+
os.makedirs(model_dir)
68+
ckpt_file = os.path.join(model_dir, os.path.basename(checkpoint))
69+
if not os.path.exists(ckpt_file) or force_reload:
70+
sys.stderr.write('Downloading checkpoint from {}\n'.format(checkpoint))
71+
urllib.request.urlretrieve(checkpoint, ckpt_file)
72+
return ckpt_file
73+
74+
75+
def nvidia_hifigan(pretrained=True, **kwargs):
76+
"""TODO
77+
"""
78+
from hifigan import models as vocoder
79+
80+
force_reload = "force_reload" in kwargs and kwargs["force_reload"]
81+
fp16 = "model_math" in kwargs and kwargs["model_math"] == "fp16"
82+
83+
if pretrained:
84+
checkpoint = 'https://api.ngc.nvidia.com/v2/models/nvidia/dle/hifigan__pyt_ckpt_mode-finetune_ds-ljs22khz/versions/21.08.0_amp/files/hifigan_gen_checkpoint_10000_ft.pt'
85+
ckpt_file = _download_checkpoint(checkpoint, force_reload)
86+
ckpt = torch.load(ckpt_file)
87+
state_dict = ckpt['generator']
88+
if checkpoint_from_distributed(state_dict):
89+
state_dict = unwrap_distributed(state_dict)
90+
config = ckpt['config']
91+
train_setup = ckpt.get('train_setup', {})
92+
else:
93+
config = {'upsample_rates': [8, 8, 2, 2], 'upsample_kernel_sizes': [16, 16, 4, 4],
94+
'upsample_initial_channel': 512, 'resblock': '1', 'resblock_kernel_sizes': [3, 7, 11],
95+
'resblock_dilation_sizes': [[1, 3, 5], [1, 3, 5], [1, 3, 5]]}
96+
for k,v in kwargs.items():
97+
if k in config.keys():
98+
config[k] = v
99+
train_setup = {}
100+
101+
hifigan = vocoder.Generator(config)
102+
denoiser = None
103+
if pretrained:
104+
hifigan.load_state_dict(state_dict)
105+
hifigan.remove_weight_norm()
106+
denoiser = vocoder.Denoiser(hifigan, win_length=1024)
107+
108+
if fp16:
109+
hifigan.half()
110+
denoiser.half()
111+
112+
return hifigan, train_setup, denoiser

hubconf.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,5 +25,10 @@
2525
from PyTorch.SpeechSynthesis.Tacotron2.waveglow import nvidia_waveglow
2626
sys.path.append(os.path.join(sys.path[0], 'PyTorch/SpeechSynthesis/Tacotron2'))
2727

28+
from PyTorch.SpeechSynthesis.HiFiGAN.fastpitch import nvidia_fastpitch
29+
from PyTorch.SpeechSynthesis.HiFiGAN.fastpitch import nvidia_textprocessing_utils
30+
from PyTorch.SpeechSynthesis.HiFiGAN.hifigan import nvidia_hifigan
31+
sys.path.append(os.path.join(sys.path[0], 'PyTorch/SpeechSynthesis/HiFiGAN'))
32+
2833
from PyTorch.Forecasting.TFT.tft_torchhub import nvidia_tft, nvidia_tft_data_utils
2934
sys.path.append(os.path.join(sys.path[0], 'PyTorch/Forecasting/TFT'))

0 commit comments

Comments
 (0)