| 1 | +""" |
| 2 | +Defines a class that is used to featurize audio clips, and provide |
| 3 | +them to the network for training or testing. |
| 4 | +""" |
| 5 | + |
| 6 | +import json |
| 7 | +import numpy as np |
| 8 | +import random |
| 9 | +from python_speech_features import mfcc |
| 10 | +import librosa |
| 11 | +import scipy.io.wavfile as wav |
| 12 | +import matplotlib.pyplot as plt |
| 13 | +from mpl_toolkits.axes_grid1 import make_axes_locatable |
| 14 | + |
| 15 | +from utils import calc_feat_dim, spectrogram_from_file, text_to_int_sequence |
| 16 | +from utils import conv_output_length |
| 17 | + |
| 18 | +RNG_SEED = 123 |
| 19 | + |
| 20 | +class AudioGenerator(): |
| 21 | + def __init__(self, step=10, window=20, max_freq=8000, mfcc_dim=13, |
| 22 | + minibatch_size=20, desc_file=None, spectrogram=True, max_duration=10.0, |
| 23 | + sort_by_duration=False): |
| 24 | + """ |
| 25 | + Params: |
| 26 | + step (int): Step size in milliseconds between windows (for spectrogram ONLY) |
| 27 | + window (int): FFT window size in milliseconds (for spectrogram ONLY) |
| 28 | + max_freq (int): Only FFT bins corresponding to frequencies between |
| 29 | + [0, max_freq] are returned (for spectrogram ONLY) |
            desc_file (str, optional): Path to a JSON-line file that contains
                labels and paths to the audio files. If this is not None,
                the metadata is loaded right away
        """

        self.feat_dim = calc_feat_dim(window, max_freq)
        self.mfcc_dim = mfcc_dim
        self.feats_mean = np.zeros((self.feat_dim,))
        self.feats_std = np.ones((self.feat_dim,))
        self.rng = random.Random(RNG_SEED)
        if desc_file is not None:
            self.load_metadata_from_desc_file(desc_file, 'train')
        self.step = step
        self.window = window
        self.max_freq = max_freq
        self.cur_train_index = 0
        self.cur_valid_index = 0
        self.cur_test_index = 0
        self.max_duration = max_duration
        self.minibatch_size = minibatch_size
        self.spectrogram = spectrogram
        self.sort_by_duration = sort_by_duration

    def get_batch(self, partition):
        """ Obtain a batch of train, validation, or test data
        """
        if partition == 'train':
            audio_paths = self.train_audio_paths
            cur_index = self.cur_train_index
            texts = self.train_texts
        elif partition == 'valid':
            audio_paths = self.valid_audio_paths
            cur_index = self.cur_valid_index
            texts = self.valid_texts
        elif partition == 'test':
            audio_paths = self.test_audio_paths
            cur_index = self.cur_test_index
            texts = self.test_texts
        else:
            raise Exception("Invalid partition. "
                "Must be train/valid/test")

        features = [self.normalize(self.featurize(a)) for a in
            audio_paths[cur_index:cur_index+self.minibatch_size]]

        # calculate necessary sizes
        max_length = max([features[i].shape[0]
            for i in range(0, self.minibatch_size)])
        max_string_length = max([len(texts[cur_index+i])
            for i in range(0, self.minibatch_size)])

        # initialize the arrays
        feature_dim = self.feat_dim if self.spectrogram else self.mfcc_dim
        X_data = np.zeros([self.minibatch_size, max_length, feature_dim])
        # pad the label array with 28 (the index used here as filler/CTC blank)
        labels = np.ones([self.minibatch_size, max_string_length]) * 28
        input_length = np.zeros([self.minibatch_size, 1])
        label_length = np.zeros([self.minibatch_size, 1])

        for i in range(0, self.minibatch_size):
            # calculate X_data & input_length
            feat = features[i]
            input_length[i] = feat.shape[0]
            X_data[i, :feat.shape[0], :] = feat

            # calculate labels & label_length
            label = np.array(text_to_int_sequence(texts[cur_index+i]))
            labels[i, :len(label)] = label
            label_length[i] = len(label)

        # return the arrays
        outputs = {'ctc': np.zeros([self.minibatch_size])}
        inputs = {'the_input': X_data,
                  'the_labels': labels,
                  'input_length': input_length,
                  'label_length': label_length
                  }
        return (inputs, outputs)

    def shuffle_data_by_partition(self, partition):
        """ Shuffle the training or validation data
        """
        if partition == 'train':
            self.train_audio_paths, self.train_durations, self.train_texts = shuffle_data(
                self.train_audio_paths, self.train_durations, self.train_texts)
        elif partition == 'valid':
            self.valid_audio_paths, self.valid_durations, self.valid_texts = shuffle_data(
                self.valid_audio_paths, self.valid_durations, self.valid_texts)
        else:
            raise Exception("Invalid partition. "
                "Must be train/valid")

    def sort_data_by_duration(self, partition):
        """ Sort the training or validation sets by (increasing) duration
        """
        if partition == 'train':
            self.train_audio_paths, self.train_durations, self.train_texts = sort_data(
                self.train_audio_paths, self.train_durations, self.train_texts)
        elif partition == 'valid':
            self.valid_audio_paths, self.valid_durations, self.valid_texts = sort_data(
                self.valid_audio_paths, self.valid_durations, self.valid_texts)
        else:
            raise Exception("Invalid partition. "
                "Must be train/valid")

    def next_train(self):
        """ Obtain a batch of training data
        """
        while True:
            ret = self.get_batch('train')
            self.cur_train_index += self.minibatch_size
            if self.cur_train_index >= len(self.train_texts) - self.minibatch_size:
                self.cur_train_index = 0
                self.shuffle_data_by_partition('train')
            yield ret

    def next_valid(self):
        """ Obtain a batch of validation data
        """
        while True:
            ret = self.get_batch('valid')
            self.cur_valid_index += self.minibatch_size
            if self.cur_valid_index >= len(self.valid_texts) - self.minibatch_size:
                self.cur_valid_index = 0
                self.shuffle_data_by_partition('valid')
            yield ret

    def next_test(self):
        """ Obtain a batch of test data
        """
        while True:
            ret = self.get_batch('test')
            self.cur_test_index += self.minibatch_size
            if self.cur_test_index >= len(self.test_texts) - self.minibatch_size:
                self.cur_test_index = 0
            yield ret

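    # The generators above are meant to feed a Keras training loop. A minimal
    # sketch (assumes a compiled model whose input layers are named
    # 'the_input', 'the_labels', 'input_length', 'label_length' and whose
    # output layer is named 'ctc', matching the dicts built in get_batch):
    #
    #   audio_gen = AudioGenerator(minibatch_size=20)
    #   audio_gen.load_train_data()
    #   audio_gen.load_validation_data()
    #   model.fit_generator(
    #       generator=audio_gen.next_train(),
    #       steps_per_epoch=len(audio_gen.train_audio_paths) // 20,
    #       epochs=20,
    #       validation_data=audio_gen.next_valid(),
    #       validation_steps=len(audio_gen.valid_audio_paths) // 20)
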
    def load_train_data(self, desc_file='train_corpus.json'):
        self.load_metadata_from_desc_file(desc_file, 'train')
        self.fit_train()
        if self.sort_by_duration:
            self.sort_data_by_duration('train')

    def load_validation_data(self, desc_file='valid_corpus.json'):
        self.load_metadata_from_desc_file(desc_file, 'validation')
        if self.sort_by_duration:
            self.sort_data_by_duration('valid')

    def load_test_data(self, desc_file='test_corpus.json'):
        self.load_metadata_from_desc_file(desc_file, 'test')

    def load_metadata_from_desc_file(self, desc_file, partition):
        """ Read metadata from a JSON-line file
            (possibly takes long, depending on the filesize)
        Params:
            desc_file (str): Path to a JSON-line file that contains labels and
                paths to the audio files
            partition (str): One of 'train', 'validation' or 'test'
        """
        audio_paths, durations, texts = [], [], []
        with open(desc_file) as json_line_file:
            for line_num, json_line in enumerate(json_line_file):
                try:
                    spec = json.loads(json_line)
                    if float(spec['duration']) > self.max_duration:
                        continue
                    audio_paths.append(spec['key'])
                    durations.append(float(spec['duration']))
                    texts.append(spec['text'])
                except Exception:
                    # Change to (KeyError, ValueError) or
                    # (KeyError, json.decoder.JSONDecodeError), depending on
                    # json module version
                    print('Error reading line #{}: {}'
                        .format(line_num, json_line))
        if partition == 'train':
            self.train_audio_paths = audio_paths
            self.train_durations = durations
            self.train_texts = texts
        elif partition == 'validation':
            self.valid_audio_paths = audio_paths
            self.valid_durations = durations
            self.valid_texts = texts
        elif partition == 'test':
            self.test_audio_paths = audio_paths
            self.test_durations = durations
            self.test_texts = texts
        else:
            raise Exception("Invalid partition to load metadata. "
                "Must be train/validation/test")
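
    # Each line of the description file is a standalone JSON object with the
    # three fields read above; a hypothetical example line:
    #
    #   {"key": "data/sample-0001.wav", "duration": 5.86, "text": "go do you hear"}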

    def fit_train(self, k_samples=100):
        """ Estimate the mean and std of the features from the training set
        Params:
            k_samples (int): Use this number of samples for estimation
        """
        k_samples = min(k_samples, len(self.train_audio_paths))
        samples = self.rng.sample(self.train_audio_paths, k_samples)
        feats = [self.featurize(s) for s in samples]
        feats = np.vstack(feats)
        self.feats_mean = np.mean(feats, axis=0)
        self.feats_std = np.std(feats, axis=0)

    def featurize(self, audio_clip):
        """ For a given audio clip, calculate the corresponding feature
        Params:
            audio_clip (str): Path to the audio clip
        """
        if self.spectrogram:
            return spectrogram_from_file(
                audio_clip, step=self.step, window=self.window,
                max_freq=self.max_freq)
        else:
            (rate, sig) = wav.read(audio_clip)
            return mfcc(sig, rate, numcep=self.mfcc_dim)

    def normalize(self, feature, eps=1e-14):
        """ Center and scale a feature using the mean and std
        Params:
            feature (numpy.ndarray): Feature to normalize
        """
        return (feature - self.feats_mean) / (self.feats_std + eps)

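# Minimal usage sketch (assumes 'train_corpus.json' exists and each line
# follows the JSON-line format that load_metadata_from_desc_file expects):
#
#   audio_gen = AudioGenerator(spectrogram=True, minibatch_size=20)
#   audio_gen.load_train_data('train_corpus.json')
#   inputs, outputs = audio_gen.get_batch('train')
#   # inputs['the_input']  -> (20, max_length, feat_dim) padded features
#   # inputs['the_labels'] -> (20, max_string_length) padded int labels
#   # outputs['ctc']       -> dummy targets for a Keras CTC loss layer
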
def shuffle_data(audio_paths, durations, texts):
    """ Shuffle the data (called after making a complete pass through
        training or validation data during the training process)
    Params:
        audio_paths (list): Paths to audio clips
        durations (list): Durations of utterances for each audio clip
        texts (list): Sentences uttered in each audio clip
    """
    p = np.random.permutation(len(audio_paths))
    audio_paths = [audio_paths[i] for i in p]
    durations = [durations[i] for i in p]
    texts = [texts[i] for i in p]
    return audio_paths, durations, texts

def sort_data(audio_paths, durations, texts):
    """ Sort the data by duration
    Params:
        audio_paths (list): Paths to audio clips
        durations (list): Durations of utterances for each audio clip
        texts (list): Sentences uttered in each audio clip
    """
    p = np.argsort(durations).tolist()
    audio_paths = [audio_paths[i] for i in p]
    durations = [durations[i] for i in p]
    texts = [texts[i] for i in p]
    return audio_paths, durations, texts
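
# Toy sanity check for the two helpers above (safe to run anywhere):
#
#   paths, durs, texts = ['a.wav', 'b.wav', 'c.wav'], [3.2, 1.1, 2.5], ['three', 'one', 'two']
#   sort_data(paths, durs, texts)
#   # -> (['b.wav', 'c.wav', 'a.wav'], [1.1, 2.5, 3.2], ['one', 'two', 'three'])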

def vis_train_features(index=0):
    """ Visualize the data point in the training set at the supplied index
    """
    # obtain spectrogram
    audio_gen = AudioGenerator(spectrogram=True)
    audio_gen.load_train_data()
    vis_audio_path = audio_gen.train_audio_paths[index]
    vis_spectrogram_feature = audio_gen.normalize(audio_gen.featurize(vis_audio_path))
    # obtain mfcc
    audio_gen = AudioGenerator(spectrogram=False)
    audio_gen.load_train_data()
    vis_mfcc_feature = audio_gen.normalize(audio_gen.featurize(vis_audio_path))
    # obtain text label
    vis_text = audio_gen.train_texts[index]
    # obtain raw audio
    vis_raw_audio, _ = librosa.load(vis_audio_path)
    # print total number of training examples
    print('There are %d total training examples.' % len(audio_gen.train_audio_paths))
    # return labels for plotting
    return vis_text, vis_raw_audio, vis_mfcc_feature, vis_spectrogram_feature, vis_audio_path


def plot_raw_audio(vis_raw_audio):
    # plot the raw audio signal
    fig = plt.figure(figsize=(12, 3))
    ax = fig.add_subplot(111)
    steps = len(vis_raw_audio)
    ax.plot(np.linspace(1, steps, steps), vis_raw_audio)
    plt.title('Audio Signal')
    plt.xlabel('Time')
    plt.ylabel('Amplitude')
    plt.show()

def plot_mfcc_feature(vis_mfcc_feature):
    # plot the MFCC feature
    fig = plt.figure(figsize=(12, 5))
    ax = fig.add_subplot(111)
    im = ax.imshow(vis_mfcc_feature, cmap=plt.cm.jet, aspect='auto')
    plt.title('Normalized MFCC')
    plt.ylabel('Time')
    plt.xlabel('MFCC Coefficient')
    divider = make_axes_locatable(ax)
    cax = divider.append_axes("right", size="5%", pad=0.05)
    plt.colorbar(im, cax=cax)
    ax.set_xticks(np.arange(0, 13, 2), minor=False)
    plt.show()

def plot_spectrogram_feature(vis_spectrogram_feature):
    # plot the normalized spectrogram
    fig = plt.figure(figsize=(12, 5))
    ax = fig.add_subplot(111)
    im = ax.imshow(vis_spectrogram_feature, cmap=plt.cm.jet, aspect='auto')
    plt.title('Normalized Spectrogram')
    plt.ylabel('Time')
    plt.xlabel('Frequency')
    divider = make_axes_locatable(ax)
    cax = divider.append_axes("right", size="5%", pad=0.05)
    plt.colorbar(im, cax=cax)
    plt.show()

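
if __name__ == '__main__':
    # Demo sketch: visualize the first training example. Assumes
    # 'train_corpus.json' is present and points at readable .wav files.
    (vis_text, vis_raw_audio, vis_mfcc_feature,
     vis_spectrogram_feature, vis_audio_path) = vis_train_features(index=0)
    print('Transcript:', vis_text)
    plot_raw_audio(vis_raw_audio)
    plot_mfcc_feature(vis_mfcc_feature)
    plot_spectrogram_feature(vis_spectrogram_feature)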