# coding: utf-8
"""
Generate ground truth-aligned predictions

usage: generate_aligned_predictions.py [options] <checkpoint> <in_dir> <out_dir>

options:
    --hparams=<params>        Hyper parameters [default: ].
    --overwrite               Overwrite audio and mel outputs.
    -h, --help                Show help message.
"""
from docopt import docopt
import os
from tqdm import tqdm
import importlib
from os.path import join
from warnings import warn
import sys

import numpy as np
import torch
from torch.autograd import Variable
from torch import nn
from torch.nn import functional as F

# The deepvoice3 model
from deepvoice3_pytorch import frontend
from hparams import hparams

use_cuda = torch.cuda.is_available()
_frontend = None  # to be set later


def preprocess(model, in_dir, out_dir, text, audio_filename, mel_filename,
               p=0, speaker_id=None,
               fast=False):
    """Generate ground truth-aligned prediction

    The output of the network and the corresponding audio are saved after time
    resolution adjustment if the overwrite flag is specified.
    """
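    # r: number of mel frames the decoder predicts per step (reduction factor).
    # downsample_step: factor by which the target mel is downsampled below
    # before being fed to the decoder.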
    r = hparams.outputs_per_step
    downsample_step = hparams.downsample_step

    if use_cuda:
        model = model.cuda()
    model.eval()
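    # make_generation_fast_ speeds up inference (in deepvoice3_pytorch it
    # removes weight normalization).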
    if fast:
        model.make_generation_fast_()

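    # Load the ground-truth mel spectrogram and add a batch dimension:
    # (T, num_mels) -> (1, T, num_mels).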
    mel_org = np.load(join(in_dir, mel_filename))
    mel = Variable(torch.from_numpy(mel_org)).unsqueeze(0).contiguous()

    # Downsample mel spectrogram
    if downsample_step > 1:
        mel = mel[:, 0::downsample_step, :].contiguous()

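    # Decoder frame positions are 1-based (0 is typically reserved for padding
    # in the position embeddings).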
    decoder_target_len = mel.shape[1] // r
    s, e = 1, decoder_target_len + 1
    frame_positions = torch.arange(s, e).long().unsqueeze(0)
    frame_positions = Variable(frame_positions)

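    # Encode the text with the frontend (p is the probability of using phoneme
    # representations where the frontend supports it) and build 1-based text positions.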
    sequence = np.array(_frontend.text_to_sequence(text, p=p))
    sequence = Variable(torch.from_numpy(sequence)).unsqueeze(0)
    text_positions = torch.arange(1, sequence.size(-1) + 1).unsqueeze(0).long()
    text_positions = Variable(text_positions)
    speaker_ids = None if speaker_id is None else Variable(torch.LongTensor([speaker_id]))
    if use_cuda:
        sequence = sequence.cuda()
        text_positions = text_positions.cuda()
        speaker_ids = None if speaker_ids is None else speaker_ids.cuda()
        mel = mel.cuda()
        frame_positions = frame_positions.cuda()

    # Teacher-forced decoding: the ground-truth mel is fed as the decoder input,
    # so the predicted mel stays time-aligned with the target.
    mel_outputs, _, _, _ = model(
        sequence, mel, text_positions=text_positions,
        frame_positions=frame_positions, speaker_ids=speaker_ids)

    mel_output = mel_outputs[0].data.cpu().numpy()

    # **Time resolution adjustment**
    # Remove the beginning audio used for the first mel prediction
    wav = np.load(join(in_dir, audio_filename))[hparams.hop_size * downsample_step:]
    assert len(wav) % hparams.hop_size == 0

    # Coarsely upsample the mel just for convenience, so that the conditional
    # features can later be upsampled by hop_size in WaveNet
    if downsample_step > 0:
        mel_output = np.repeat(mel_output, downsample_step, axis=0)
    # After downsampling then upsampling, the length should be equal to or
    # larger than the original mel length
    assert mel_output.shape[0] >= mel_org.shape[0]

    # Trim mel output
    expected_frames = len(wav) // hparams.hop_size
    mel_output = mel_output[:expected_frames]

    # Make sure we have correct lengths
    assert mel_output.shape[0] * hparams.hop_size == len(wav)

    timesteps = len(wav)

    # save
    np.save(join(out_dir, audio_filename), wav.astype(np.int16),
            allow_pickle=False)
    np.save(join(out_dir, mel_filename), mel_output.astype(np.float32),
            allow_pickle=False)

    if speaker_id is None:
        return (audio_filename, mel_filename, timesteps, text)
    else:
        return (audio_filename, mel_filename, timesteps, text, speaker_id)


def write_metadata(metadata, out_dir):
    with open(os.path.join(out_dir, 'train.txt'), 'w', encoding='utf-8') as f:
        for m in metadata:
            f.write('|'.join([str(x) for x in m]) + '\n')
    frames = sum([m[2] for m in metadata])
    sr = hparams.sample_rate
    hours = frames / sr / 3600
    print('Wrote %d utterances, %d time steps (%.2f hours)' % (len(metadata), frames, hours))
    print('Max input length: %d' % max(len(m[3]) for m in metadata))
    print('Max output length: %d' % max(m[2] for m in metadata))


if __name__ == "__main__":
    args = docopt(__doc__)
    checkpoint_path = args["<checkpoint>"]
    in_dir = args["<in_dir>"]
    out_dir = args["<out_dir>"]

    # Override hyper parameters
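    # --hparams is a comma-separated list of name=value overrides.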
    hparams.parse(args["--hparams"])
    assert hparams.name == "deepvoice3"

    # Presets
    if hparams.preset is not None and hparams.preset != "":
        preset = hparams.presets[hparams.preset]
        import json
        hparams.parse_json(json.dumps(preset))
        print("Override hyper parameters with preset \"{}\": {}".format(
            hparams.preset, json.dumps(preset, indent=4)))

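    # train.py keeps its own module-level _frontend, so mirror the frontend there
    # before building the model.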
    _frontend = getattr(frontend, hparams.frontend)
    import train
    train._frontend = _frontend
    from train import build_model

    model = build_model()

    # Load checkpoint
    print("Load checkpoint from {}".format(checkpoint_path))
    checkpoint = torch.load(checkpoint_path)
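    # Checkpoints written by train.py store the model weights under "state_dict".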
    model.load_state_dict(checkpoint["state_dict"])

    os.makedirs(out_dir, exist_ok=True)
    results = []
    with open(os.path.join(in_dir, "train.txt")) as f:
        lines = f.readlines()

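    # Each line of train.txt has pipe-separated fields:
    # audio_filename|mel_filename|<ignored here>|text[|speaker_id]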
    for idx in tqdm(range(len(lines))):
        l = lines[idx]
        l = l[:-1].split("|")
        audio_filename, mel_filename, _, text = l[:4]
        speaker_id = int(l[4]) if len(l) > 4 else None
        if text == "N/A":
            raise RuntimeError("No transcription available")

        result = preprocess(model, in_dir, out_dir, text, audio_filename,
                            mel_filename, p=0,
                            speaker_id=speaker_id, fast=True)
        results.append(result)

    write_metadata(results, out_dir)

    sys.exit(0)