 # *****************************************************************************
 import torch
 torch._C._jit_set_autocast_mode(False)
-from torch.autograd import Variable
+import torch.nn as nn
 import torch.nn.functional as F
+from torch.autograd import Variable


 @torch.jit.script
-def fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels):
-    n_channels_int = n_channels[0]
+def fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels: int):
+    n_channels_int = n_channels
     in_act = input_a + input_b
     t_act = torch.tanh(in_act[:, :n_channels_int, :])
     s_act = torch.sigmoid(in_act[:, n_channels_int:, :])
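Note: with n_channels typed as a plain int, TorchScript no longer needs a
one-element IntTensor allocated at every call site. A minimal sketch of the
call-site difference (tensor shapes here are illustrative only):

a = torch.randn(1, 8, 16)  # [batch, 2 * n_channels, time]
b = torch.randn(1, 8, 16)
out = fused_add_tanh_sigmoid_multiply(a, b, 4)  # was: torch.IntTensor([4])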
@@ -73,22 +74,14 @@ def forward(self, z):
         z = self.conv(z)
         return z, log_det_W

-
     def infer(self, z):
-        # shape
-        batch_size, group_size, n_of_groups = z.size()
-
-        W = self.conv.weight.squeeze()
+        self._invert()
+        return F.conv1d(z, self.W_inverse, bias=None, stride=1, padding=0)

+    def _invert(self):
         if not hasattr(self, 'W_inverse'):
-            # Reverse computation
-            W_inverse = W.float().inverse()
-            W_inverse = Variable(W_inverse[..., None])
-            if z.type() == 'torch.cuda.HalfTensor' or z.type() == 'torch.HalfTensor':
-                W_inverse = W_inverse.half()
-            self.W_inverse = W_inverse
-        z = F.conv1d(z, self.W_inverse, bias=None, stride=1, padding=0)
-        return z
+            W = self.conv.weight.squeeze()
+            self.W_inverse = W.float().inverse().unsqueeze(-1).to(W.dtype)


 class WN(torch.nn.Module):
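Note: caching the inverse weight in _invert() makes the matrix inversion a
one-time cost, so infer() reduces to a single F.conv1d. A round-trip sketch,
assuming the class's existing one-argument (channel count) constructor:

conv = Invertible1x1Conv(8)  # 8 channels; constructor signature assumed
z = torch.randn(1, 8, 100)
y, _ = conv(z)               # forward applies W as a 1x1 convolution
z_rec = conv.infer(y)        # inverse pass via the cached W_inverse
print(torch.allclose(z, z_rec, atol=1e-4))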
@@ -142,27 +135,25 @@ def __init__(self, n_in_channels, n_mel_channels, n_layers, n_channels,
                 res_skip_layer, name='weight')
             self.res_skip_layers.append(res_skip_layer)

-    def forward(self, forward_input):
-        audio, spect = forward_input
+    def forward(self, audio, spect):
         audio = self.start(audio)

-        for i in range(self.n_layers):
+        output = 0
+        for i, (in_layer, cond_layer, res_skip_layer) in enumerate(
+                zip(self.in_layers, self.cond_layers, self.res_skip_layers)):
             acts = fused_add_tanh_sigmoid_multiply(
-                self.in_layers[i](audio),
-                self.cond_layers[i](spect),
-                torch.IntTensor([self.n_channels]))
+                in_layer(audio),
+                cond_layer(spect),
+                self.n_channels)

-            res_skip_acts = self.res_skip_layers[i](acts)
+            res_skip_acts = res_skip_layer(acts)
             if i < self.n_layers - 1:
                 audio = res_skip_acts[:, :self.n_channels, :] + audio
                 skip_acts = res_skip_acts[:, self.n_channels:, :]
             else:
                 skip_acts = res_skip_acts

-            if i == 0:
-                output = skip_acts
-            else:
-                output = skip_acts + output
+            output += skip_acts
         return self.end(output)


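Note: WN.forward now takes two tensors instead of a packed tuple, which
TorchScript can type directly, and iterating with zip over the ModuleLists
replaces integer indexing. A sketch of the new call; the constructor
arguments, including the trailing kernel_size, are assumptions read off the
hunk header above:

wn = WN(n_in_channels=4, n_mel_channels=80, n_layers=8, n_channels=256,
        kernel_size=3)
audio = torch.randn(1, 4, 200)   # [batch, n_in_channels, time]
spect = torch.randn(1, 80, 200)  # conditioning, same time axis
out = wn(audio, spect)           # was: wn((audio, spect))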
@@ -229,7 +220,7 @@ def forward(self, forward_input):
             audio_0 = audio[:, :n_half, :]
             audio_1 = audio[:, n_half:, :]

-            output = self.WN[k]((audio_0, spect))
+            output = self.WN[k](audio_0, spect)
             log_s = output[:, n_half:, :]
             b = output[:, :n_half, :]
             audio_1 = torch.exp(log_s) * audio_1 + b
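Note: the coupling applied here is affine, which is what lets infer() below
undo it exactly: forward computes y1 = exp(log_s) * x1 + b, so the inverse is
x1 = (y1 - b) / exp(log_s). A quick numerical check:

x1 = torch.randn(1, 4, 10)
log_s = torch.randn(1, 4, 10)
b = torch.randn(1, 4, 10)
y1 = torch.exp(log_s) * x1 + b
print(torch.allclose((y1 - b) / torch.exp(log_s), x1, atol=1e-5))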
@@ -262,7 +253,7 @@ def infer(self, spect, sigma=1.0):
             audio_0 = audio[:, :n_half, :]
             audio_1 = audio[:, n_half:, :]

-            output = self.WN[k]((audio_0, spect))
+            output = self.WN[k](audio_0, spect)
             s = output[:, n_half:, :]
             b = output[:, :n_half, :]
             audio_1 = (audio_1 - b) / torch.exp(s)
@@ -308,7 +299,7 @@ def infer_onnx(self, spect, z, sigma=0.9):
             audio_0 = audio[:, :n_half, :]
             audio_1 = audio[:, n_half:(n_half + n_half), :]

-            output = self.WN[k]((audio_0, spect))
+            output = self.WN[k](audio_0, spect)
             s = output[:, n_half:(n_half + n_half), :]
             b = output[:, :n_half, :]
             audio_1 = (audio_1 - b) / torch.exp(s)
@@ -323,6 +314,53 @@ def infer_onnx(self, spect, z, sigma=0.9):

         return audio

+    def _infer_ts(self, spect, sigma: float = 1.0):
+
+        spect = self.upsample(spect)
+        # trim conv artifacts. maybe pad spec to kernel multiple
+        time_cutoff = self.upsample.kernel_size[0] - self.upsample.stride[0]
+        spect = spect[:, :, :-time_cutoff]
+
+        spect = spect.unfold(2, self.n_group, self.n_group).permute(0, 2, 1, 3)
+        spect = spect.contiguous().view(spect.size(0), spect.size(1), -1)
+        spect = spect.permute(0, 2, 1)
+
+        audio = torch.randn(spect.size(0), self.n_remaining_channels,
+                            spect.size(2), device=spect.device,
+                            dtype=spect.dtype)
+        audio *= sigma
+
+        for kk, (wn, convinv) in enumerate(zip(self.WN_rev, self.convinv_rev)):
+            k = self.n_flows - kk - 1
+            n_half = int(audio.size(1) / 2)
+            audio_0 = audio[:, :n_half, :]
+            audio_1 = audio[:, n_half:, :]
+
+            output = wn(audio_0, spect)
+            s = output[:, n_half:, :]
+            b = output[:, :n_half, :]
+            audio_1 = (audio_1 - b) / torch.exp(s)
+            audio = torch.cat([audio_0, audio_1], 1)
+
+            audio = convinv.infer(audio)
+
+            if k % self.n_early_every == 0 and k > 0:
+                z = torch.randn(spect.size(0), self.n_early_size,
+                                spect.size(2), device=spect.device,
+                                dtype=spect.dtype)
+                audio = torch.cat((sigma * z, audio), 1)
+
+        return audio.permute(0, 2, 1).contiguous().view(audio.size(0), -1).data
+
+    def make_ts_scriptable(self, forward_is_infer=True):
+        self.WN_rev = torch.nn.ModuleList(reversed(self.WN))
+        self.convinv_rev = torch.nn.ModuleList(reversed(self.convinv))
+        for conv in self.convinv_rev:
+            conv._invert()
+
+        self.infer = self._infer_ts
+        if forward_is_infer:
+            self.forward = self._infer_ts

     @staticmethod
     def remove_weightnorm(model):
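Note: make_ts_scriptable freezes the reversed flow order into ModuleLists and
pre-computes every W_inverse so the whole model can pass through
torch.jit.script. A minimal sketch of the intended export flow; the WaveGlow
constructor arguments and config values are assumptions, not taken from this
diff:

model = WaveGlow(n_mel_channels=80, n_flows=12, n_group=8,
                 n_early_every=4, n_early_size=2,
                 WN_config={'n_layers': 8, 'n_channels': 256,
                            'kernel_size': 3})
model = WaveGlow.remove_weightnorm(model)
model.make_ts_scriptable(forward_is_infer=True)
model = torch.jit.script(model)
audio = model(torch.randn(1, 80, 600))  # mel spectrogram in, waveform out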