
Commit 6a16011

Merge: [WaveGlow/PyT] Enable TorchScript
Merge commit, 2 parents: 8cdaba1 + 0e1c6c5

File tree: 3 files changed, +84 −42 lines

PyTorch/SpeechSynthesis/Tacotron2/inference.py

Lines changed: 10 additions & 5 deletions
@@ -106,13 +106,15 @@ def unwrap_distributed(state_dict):
     return new_state_dict
 
 
-def load_and_setup_model(model_name, parser, checkpoint, fp16_run, cpu_run, forward_is_infer=False):
+def load_and_setup_model(model_name, parser, checkpoint, fp16_run, cpu_run,
+                         forward_is_infer=False, jittable=False):
     model_parser = models.model_parser(model_name, parser, add_help=False)
     model_args, _ = model_parser.parse_known_args()
 
     model_config = models.get_model_config(model_name, model_args)
     model = models.get_model(model_name, model_config, cpu_run=cpu_run,
-                             forward_is_infer=forward_is_infer)
+                             forward_is_infer=forward_is_infer,
+                             jittable=jittable)
 
     if checkpoint is not None:
         if cpu_run:
@@ -207,11 +209,14 @@ def main():
     tacotron2 = load_and_setup_model('Tacotron2', parser, args.tacotron2,
                                      args.fp16, args.cpu, forward_is_infer=True)
     waveglow = load_and_setup_model('WaveGlow', parser, args.waveglow,
-                                    args.fp16, args.cpu, forward_is_infer=True)
+                                    args.fp16, args.cpu, forward_is_infer=True,
+                                    jittable=True)
     denoiser = Denoiser(waveglow)
     if not args.cpu:
         denoiser.cuda()
 
+    waveglow.make_ts_scriptable()
+    jitted_waveglow = torch.jit.script(waveglow)
     jitted_tacotron2 = torch.jit.script(tacotron2)
 
     texts = []
@@ -231,7 +236,7 @@ def main():
     for i in range(3):
         with torch.no_grad():
             mel, mel_lengths, _ = jitted_tacotron2(sequence, input_lengths)
-            _ = waveglow(mel)
+            _ = jitted_waveglow(mel)
 
     measurements = {}
 
@@ -241,7 +246,7 @@ def main():
         mel, mel_lengths, alignments = jitted_tacotron2(sequences_padded, input_lengths)
 
     with torch.no_grad(), MeasureTime(measurements, "waveglow_time", args.cpu):
-        audios = waveglow(mel, sigma=args.sigma_infer)
+        audios = jitted_waveglow(mel, sigma=args.sigma_infer)
         audios = audios.float()
     with torch.no_grad(), MeasureTime(measurements, "denoiser_time", args.cpu):
        audios = denoiser(audios, strength=args.denoising_strength).squeeze(1)
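For reference (not part of the diff): a minimal sketch of the scripted-WaveGlow flow these changes enable, outside the inference.py script. The hyperparameter values and the random mel tensor below are illustrative placeholders; in the repo they come from the argument parser and the checkpoint.

import torch
from waveglow.model import WaveGlow

# Illustrative config only; use the repo's parser/checkpoint values in practice.
model_config = dict(n_mel_channels=80, n_flows=12, n_group=8,
                    n_early_every=4, n_early_size=2,
                    WN_config=dict(n_layers=8, n_channels=512, kernel_size=3))

waveglow = WaveGlow(**model_config)
waveglow = WaveGlow.remove_weightnorm(waveglow)  # fold weight norm into the conv weights before scripting
waveglow.eval()

waveglow.make_ts_scriptable()                    # added in this commit
jitted_waveglow = torch.jit.script(waveglow)     # forward now dispatches to _infer_ts

mel = torch.randn(1, 80, 620)                    # (batch, n_mel_channels, frames), placeholder input
with torch.no_grad():
    audio = jitted_waveglow(mel, sigma=0.9)      # (batch, n_samples) waveform batch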

PyTorch/SpeechSynthesis/Tacotron2/models.py

Lines changed: 6 additions & 7 deletions
@@ -63,7 +63,8 @@ def init_bn(module):
 
 
 def get_model(model_name, model_config, cpu_run,
-              uniform_initialize_bn_weight=False, forward_is_infer=False):
+              uniform_initialize_bn_weight=False, forward_is_infer=False,
+              jittable=False):
     """ Code chooses a model based on name"""
     model = None
     if model_name == 'Tacotron2':
@@ -75,13 +76,11 @@ def forward(self, inputs, input_lengths):
         else:
             model = Tacotron2(**model_config)
     elif model_name == 'WaveGlow':
+
+        model = WaveGlow(**model_config)
         if forward_is_infer:
-            class WaveGlow__forward_is_infer(WaveGlow):
-                def forward(self, spect, sigma=1.0):
-                    return self.infer(spect, sigma)
-            model = WaveGlow__forward_is_infer(**model_config)
-        else:
-            model = WaveGlow(**model_config)
+            model.forward = model.infer
+
     else:
         raise NotImplementedError(model_name)
 
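The dynamically defined WaveGlow__forward_is_infer subclass is replaced by aliasing the bound method on the instance. A toy illustration of that pattern (not repo code; Toy is a made-up module):

import torch

class Toy(torch.nn.Module):
    def forward(self, x):
        return x + 1

    def infer(self, x):
        return x * 2

m = Toy()
m.forward = m.infer          # instance attribute shadows the class-level forward
print(m(torch.tensor(3)))    # tensor(6): __call__ now dispatches to infer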
PyTorch/SpeechSynthesis/Tacotron2/waveglow/model.py

Lines changed: 68 additions & 30 deletions
@@ -26,13 +26,14 @@
 # *****************************************************************************
 import torch
 torch._C._jit_set_autocast_mode(False)
-from torch.autograd import Variable
+import torch.nn as nn
 import torch.nn.functional as F
+from torch.autograd import Variable
 
 
 @torch.jit.script
-def fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels):
-    n_channels_int = n_channels[0]
+def fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels : int):
+    n_channels_int = n_channels
     in_act = input_a + input_b
     t_act = torch.tanh(in_act[:, :n_channels_int, :])
     s_act = torch.sigmoid(in_act[:, n_channels_int:, :])
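With n_channels typed as a plain int, callers no longer wrap the channel count in a one-element IntTensor. A quick hedged check of the new signature, assuming the Tacotron2 directory is on the import path; the shapes are arbitrary:

import torch
from waveglow.model import fused_add_tanh_sigmoid_multiply

a = torch.randn(2, 8, 16)   # 2 * n_channels rows: tanh half plus sigmoid half
b = torch.randn(2, 8, 16)
out = fused_add_tanh_sigmoid_multiply(a, b, 4)   # plain int, no torch.IntTensor([4])
print(out.shape)            # torch.Size([2, 4, 16])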
@@ -73,22 +74,14 @@ def forward(self, z):
         z = self.conv(z)
         return z, log_det_W
 
-
     def infer(self, z):
-        # shape
-        batch_size, group_size, n_of_groups = z.size()
-
-        W = self.conv.weight.squeeze()
+        self._invert()
+        return F.conv1d(z, self.W_inverse, bias=None, stride=1, padding=0)
 
+    def _invert(self):
         if not hasattr(self, 'W_inverse'):
-            # Reverse computation
-            W_inverse = W.float().inverse()
-            W_inverse = Variable(W_inverse[..., None])
-            if z.type() == 'torch.cuda.HalfTensor' or z.type() == 'torch.HalfTensor':
-                W_inverse = W_inverse.half()
-            self.W_inverse = W_inverse
-        z = F.conv1d(z, self.W_inverse, bias=None, stride=1, padding=0)
-        return z
+            W = self.conv.weight.squeeze()
+            self.W_inverse = W.float().inverse().unsqueeze(-1).to(W.dtype)
 
 
 class WN(torch.nn.Module):
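infer is now a plain F.conv1d against a cached W_inverse, with the inversion factored into _invert so make_ts_scriptable can precompute it for every flow before scripting. A hedged round-trip sketch, assuming the repo's Invertible1x1Conv is importable; the channel count and lengths are arbitrary:

import torch
from waveglow.model import Invertible1x1Conv

conv = Invertible1x1Conv(8).eval()
x = torch.randn(1, 8, 100)
with torch.no_grad():
    y, _log_det = conv(x)            # forward: 1x1 convolution with W
    conv._invert()                   # caches W_inverse, as make_ts_scriptable does per flow
    x_rec = conv.infer(y)            # F.conv1d with the cached inverse
print(torch.allclose(x, x_rec, atol=1e-4))   # True up to numerical error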
@@ -142,27 +135,25 @@ def __init__(self, n_in_channels, n_mel_channels, n_layers, n_channels,
                 res_skip_layer, name='weight')
             self.res_skip_layers.append(res_skip_layer)
 
-    def forward(self, forward_input):
-        audio, spect = forward_input
+    def forward(self, audio, spect):
         audio = self.start(audio)
 
-        for i in range(self.n_layers):
+        output = 0
+        for i, (in_layer, cond_layer, res_skip_layer) in enumerate(
+                zip(self.in_layers, self.cond_layers, self.res_skip_layers)):
             acts = fused_add_tanh_sigmoid_multiply(
-                self.in_layers[i](audio),
-                self.cond_layers[i](spect),
-                torch.IntTensor([self.n_channels]))
+                in_layer(audio),
+                cond_layer(spect),
+                self.n_channels)
 
-            res_skip_acts = self.res_skip_layers[i](acts)
+            res_skip_acts = res_skip_layer(acts)
             if i < self.n_layers - 1:
                 audio = res_skip_acts[:, :self.n_channels, :] + audio
                 skip_acts = res_skip_acts[:, self.n_channels:, :]
             else:
                 skip_acts = res_skip_acts
 
-            if i == 0:
-                output = skip_acts
-            else:
-                output = skip_acts + output
+            output += skip_acts
         return self.end(output)
 
 
@@ -229,7 +220,7 @@ def forward(self, forward_input):
             audio_0 = audio[:, :n_half, :]
             audio_1 = audio[:, n_half:, :]
 
-            output = self.WN[k]((audio_0, spect))
+            output = self.WN[k](audio_0, spect)
             log_s = output[:, n_half:, :]
             b = output[:, :n_half, :]
             audio_1 = torch.exp(log_s) * audio_1 + b
@@ -262,7 +253,7 @@ def infer(self, spect, sigma=1.0):
             audio_0 = audio[:, :n_half, :]
             audio_1 = audio[:, n_half:, :]
 
-            output = self.WN[k]((audio_0, spect))
+            output = self.WN[k](audio_0, spect)
             s = output[:, n_half:, :]
             b = output[:, :n_half, :]
             audio_1 = (audio_1 - b) / torch.exp(s)
@@ -308,7 +299,7 @@ def infer_onnx(self, spect, z, sigma=0.9):
             audio_0 = audio[:, :n_half, :]
             audio_1 = audio[:, n_half:(n_half+n_half), :]
 
-            output = self.WN[k]((audio_0, spect))
+            output = self.WN[k](audio_0, spect)
             s = output[:, n_half:(n_half+n_half), :]
             b = output[:, :n_half, :]
             audio_1 = (audio_1 - b) / torch.exp(s)
@@ -323,6 +314,53 @@ def infer_onnx(self, spect, z, sigma=0.9):
 
         return audio
 
+    def _infer_ts(self, spect, sigma : float=1.0):
+
+        spect = self.upsample(spect)
+        # trim conv artifacts. maybe pad spec to kernel multiple
+        time_cutoff = self.upsample.kernel_size[0] - self.upsample.stride[0]
+        spect = spect[:, :, :-time_cutoff]
+
+        spect = spect.unfold(2, self.n_group, self.n_group).permute(0, 2, 1, 3)
+        spect = spect.contiguous().view(spect.size(0), spect.size(1), -1)
+        spect = spect.permute(0, 2, 1)
+
+        audio = torch.randn(spect.size(0), self.n_remaining_channels,
+                            spect.size(2), device=spect.device,
+                            dtype=spect.dtype)
+        audio *= sigma
+
+        for kk, (wn, convinv) in enumerate(zip(self.WN_rev, self.convinv_rev)):
+            k = self.n_flows - kk - 1
+            n_half = int(audio.size(1) / 2)
+            audio_0 = audio[:, :n_half, :]
+            audio_1 = audio[:, n_half:, :]
+
+            output = wn(audio_0, spect)
+            s = output[:, n_half:, :]
+            b = output[:, :n_half, :]
+            audio_1 = (audio_1 - b) / torch.exp(s)
+            audio = torch.cat([audio_0, audio_1], 1)
+
+            audio = convinv.infer(audio)
+
+            if k % self.n_early_every == 0 and k > 0:
+                z = torch.randn(spect.size(0), self.n_early_size,
+                                spect.size(2), device=spect.device,
+                                dtype=spect.dtype)
+                audio = torch.cat((sigma * z, audio), 1)
+
+        return audio.permute(0, 2, 1).contiguous().view(audio.size(0), -1).data
+
+    def make_ts_scriptable(self, forward_is_infer=True):
+        self.WN_rev = torch.nn.ModuleList(reversed(self.WN))
+        self.convinv_rev = torch.nn.ModuleList(reversed(self.convinv))
+        for conv in self.convinv_rev:
+            conv._invert()
+
+        self.infer = self._infer_ts
+        if forward_is_infer:
+            self.forward = self._infer_ts
 
     @staticmethod
     def remove_weightnorm(model):
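Since the scripted module is self-contained, it can also be serialized and reloaded without the Python class definitions, which is the usual payoff of enabling TorchScript. A hedged sketch continuing the earlier example; `waveglow` is the prepared model from that sketch and the file name is arbitrary:

import torch

waveglow.make_ts_scriptable()
jitted = torch.jit.script(waveglow)
torch.jit.save(jitted, "waveglow_ts.pt")      # arbitrary file name

restored = torch.jit.load("waveglow_ts.pt")   # no repo code needed at load time
mel = torch.randn(1, 80, 620)                 # placeholder mel spectrogram
with torch.no_grad():
    audio = restored(mel, sigma=0.9)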
