
Commit 0daaaa5

Add script to generate training data for wavenet vocoder
1 parent e513a76 commit 0daaaa5

File tree

1 file changed (+178, −0)


generate_aligned_predictions.py

@@ -0,0 +1,178 @@
# coding: utf-8
"""
Generate ground truth-aligned predictions

usage: generate_aligned_predictions.py [options] <checkpoint> <in_dir> <out_dir>

options:
    --hparams=<params>    Hyper parameters [default: ].
    --overwrite           Overwrite audio and mel outputs.
    -h, --help            Show help message.
"""
from docopt import docopt
import os
from tqdm import tqdm
import importlib
from os.path import join
from warnings import warn
import sys

import numpy as np
import torch
from torch.autograd import Variable
from torch import nn
from torch.nn import functional as F

# The deepvoice3 model
from deepvoice3_pytorch import frontend
from hparams import hparams

use_cuda = torch.cuda.is_available()
_frontend = None  # to be set later

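
# "Ground truth-aligned" (GTA) generation: the trained deepvoice3 model is run
# with teacher forcing so that its predicted mel spectrograms line up
# frame-by-frame with the ground-truth audio; the resulting
# (audio, predicted mel) pairs serve as training data for a WaveNet vocoder.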
def preprocess(model, in_dir, out_dir, text, audio_filename, mel_filename,
               p=0, speaker_id=None,
               fast=False):
    """Generate a ground truth-aligned prediction.

    The output of the network and the corresponding audio are saved after
    time-resolution adjustment if the overwrite flag is specified.
    """
    r = hparams.outputs_per_step
    downsample_step = hparams.downsample_step

    if use_cuda:
        model = model.cuda()
    model.eval()
    if fast:
        model.make_generation_fast_()

    mel_org = np.load(join(in_dir, mel_filename))
    mel = Variable(torch.from_numpy(mel_org)).unsqueeze(0).contiguous()

    # Downsample mel spectrogram
    if downsample_step > 1:
        mel = mel[:, 0::downsample_step, :].contiguous()

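    # The decoder emits r (= outputs_per_step) frames per step, so the number
    # of decoder steps is the (downsampled) mel length divided by r. Position
    # indices start at 1; index 0 appears to be reserved for padding in the
    # position embeddings.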
    decoder_target_len = mel.shape[1] // r
    s, e = 1, decoder_target_len + 1
    frame_positions = torch.arange(s, e).long().unsqueeze(0)
    frame_positions = Variable(frame_positions)

    sequence = np.array(_frontend.text_to_sequence(text, p=p))
    sequence = Variable(torch.from_numpy(sequence)).unsqueeze(0)
    text_positions = torch.arange(1, sequence.size(-1) + 1).unsqueeze(0).long()
    text_positions = Variable(text_positions)
    speaker_ids = None if speaker_id is None else Variable(torch.LongTensor([speaker_id]))
    if use_cuda:
        sequence = sequence.cuda()
        text_positions = text_positions.cuda()
        speaker_ids = None if speaker_ids is None else speaker_ids.cuda()
        mel = mel.cuda()
        frame_positions = frame_positions.cuda()

    # **Teacher forcing** decoding
    mel_outputs, _, _, _ = model(
        sequence, mel, text_positions=text_positions,
        frame_positions=frame_positions, speaker_ids=speaker_ids)

    mel_output = mel_outputs[0].data.cpu().numpy()

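    # Because the ground-truth mel is fed as the decoder input above, the
    # predicted frames stay time-aligned with the reference audio; this is
    # what makes the output "ground truth-aligned".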
    # **Time resolution adjustment**
    # Remove the beginning of the audio used for the first mel prediction
    wav = np.load(join(in_dir, audio_filename))[hparams.hop_size * downsample_step:]
    assert len(wav) % hparams.hop_size == 0

    # Coarse upsample just for convenience
    # so that we can upsample conditional features by hop_size in wavenet
    if downsample_step > 0:
        mel_output = np.repeat(mel_output, downsample_step, axis=0)
    # After downsampling -> upsampling, the length should be equal to or
    # larger than the original mel length
    assert mel_output.shape[0] >= mel_org.shape[0]

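    # Length bookkeeping: every mel frame covers hop_size audio samples, so
    # after trimming we expect len(wav) == mel_output.shape[0] * hop_size
    # exactly, which is what the assertion below checks.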
    # Trim mel output
    expected_frames = len(wav) // hparams.hop_size
    mel_output = mel_output[:expected_frames]

    # Make sure we have correct lengths
    assert mel_output.shape[0] * hparams.hop_size == len(wav)

    timesteps = len(wav)

    # Save
    np.save(join(out_dir, audio_filename), wav.astype(np.int16),
            allow_pickle=False)
    np.save(join(out_dir, mel_filename), mel_output.astype(np.float32),
            allow_pickle=False)

    if speaker_id is None:
        return (audio_filename, mel_filename, timesteps, text)
    else:
        return (audio_filename, mel_filename, timesteps, text, speaker_id)

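
# Each metadata line written below is pipe-delimited:
#   audio_filename|mel_filename|timesteps|text[|speaker_id]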
def write_metadata(metadata, out_dir):
    with open(os.path.join(out_dir, 'train.txt'), 'w', encoding='utf-8') as f:
        for m in metadata:
            f.write('|'.join([str(x) for x in m]) + '\n')
    frames = sum([m[2] for m in metadata])
    sr = hparams.sample_rate
    hours = frames / sr / 3600
    print('Wrote %d utterances, %d time steps (%.2f hours)' % (
        len(metadata), frames, hours))
    print('Max input length: %d' % max(len(m[3]) for m in metadata))
    print('Max output length: %d' % max(m[2] for m in metadata))


if __name__ == "__main__":
    args = docopt(__doc__)
    checkpoint_path = args["<checkpoint>"]
    in_dir = args["<in_dir>"]
    out_dir = args["<out_dir>"]

    # Override hyper parameters
    hparams.parse(args["--hparams"])
    assert hparams.name == "deepvoice3"

    # Presets
    if hparams.preset is not None and hparams.preset != "":
        preset = hparams.presets[hparams.preset]
        import json
        hparams.parse_json(json.dumps(preset))
        print("Override hyper parameters with preset \"{}\": {}".format(
            hparams.preset, json.dumps(preset, indent=4)))

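    # Share the text-processing frontend with train.py so that
    # text_to_sequence here matches the one used during training.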
    _frontend = getattr(frontend, hparams.frontend)
    import train
    train._frontend = _frontend
    from train import build_model

    model = build_model()

    # Load checkpoint
    print("Load checkpoint from {}".format(checkpoint_path))
    checkpoint = torch.load(checkpoint_path)
    model.load_state_dict(checkpoint["state_dict"])

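    # Note (assumption): to restore a GPU-saved checkpoint on a CPU-only
    # machine, torch.load above would need
    # map_location=lambda storage, loc: storage.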
    os.makedirs(out_dir, exist_ok=True)
    results = []
    with open(os.path.join(in_dir, "train.txt")) as f:
        lines = f.readlines()

    for idx in tqdm(range(len(lines))):
        l = lines[idx]
        l = l[:-1].split("|")
        audio_filename, mel_filename, _, text = l[:4]
        speaker_id = int(l[4]) if len(l) > 4 else None
        if text == "N/A":
            raise RuntimeError("No transcription available")

        result = preprocess(model, in_dir, out_dir, text, audio_filename,
                            mel_filename, p=0,
                            speaker_id=speaker_id, fast=True)
        results.append(result)

    write_metadata(results, out_dir)

    sys.exit(0)
