
Commit 92983ea

Merge pull request #419 from wingsbr/librispeech

Added a librispeech data generator.

2 parents: c25e43f + 23129f2

File tree: 3 files changed (+312, −1 lines)

tensor2tensor/bin/t2t-datagen (file mode changed: 100644 → 100755)
Lines changed: 1 addition & 1 deletion

@@ -112,7 +112,7 @@ _SUPPORTED_PROBLEM_GENERATORS = {
             vocab_filename="vocab.endefr.%d" % 2**15, vocab_size=2**15),
         lambda: audio.timit_generator(
             FLAGS.data_dir, FLAGS.tmp_dir, False, 626,
-            vocab_filename="vocab.endefr.%d" % 2**15, vocab_size=2**15)),
+            vocab_filename="vocab.endefr.%d" % 2**15, vocab_size=2**15)),
 }

 # pylint: enable=g-long-lambda

tensor2tensor/data_generators/all_problems.py

Lines changed: 1 addition & 0 deletions

@@ -28,6 +28,7 @@
 from tensor2tensor.data_generators import ice_parsing
 from tensor2tensor.data_generators import image
 from tensor2tensor.data_generators import imdb
+from tensor2tensor.data_generators import librispeech
 from tensor2tensor.data_generators import lm1b
 from tensor2tensor.data_generators import multinli
 from tensor2tensor.data_generators import problem_hparams
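
The single added import is what activates the new problem: `@registry.register_problem()` in librispeech.py runs at import time, and all_problems.py is the central import list that t2t-datagen and the trainers pull in. A minimal sketch of the resulting lookup (standard t2t registry usage; not part of this commit):

    from tensor2tensor.data_generators import all_problems  # imports librispeech
    from tensor2tensor.utils import registry

    librispeech_problem = registry.problem("librispeech")  # snake_case class name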
tensor2tensor/data_generators/librispeech.py (new file)

Lines changed: 310 additions & 0 deletions

@@ -0,0 +1,310 @@
+from tensor2tensor.data_generators import problem
+from tensor2tensor.utils import registry
+from tensor2tensor.models import transformer
+from tensor2tensor.utils import modality
+from tensor2tensor.layers import common_layers
+from tensor2tensor.data_generators import text_encoder
+import random
+import tensorflow as tf
+import numpy as np
+from tensor2tensor.data_generators import generator_utils
+import os
+from subprocess import call
+import tarfile
+import wave
+
+
+_LIBRISPEECH_TRAIN_DATASETS = [
+    [
+        "http://www.openslr.org/resources/12/train-clean-100.tar.gz",  # pylint: disable=line-too-long
+        "train-clean-100"
+    ],
+    [
+        "http://www.openslr.org/resources/12/train-clean-360.tar.gz",
+        "train-clean-360"
+    ],
+    [
+        "http://www.openslr.org/resources/12/train-other-500.tar.gz",
+        "train-other-500"
+    ],
+]
+_LIBRISPEECH_TEST_DATASETS = [
+    [
+        "http://www.openslr.org/resources/12/dev-clean.tar.gz",
+        "dev-clean"
+    ],
+    [
+        "http://www.openslr.org/resources/12/dev-other.tar.gz",
+        "dev-other"
+    ],
+]
+
+
+def _collect_data(directory, input_ext, transcription_ext):
+  """Traverses directory collecting input and target files."""
+  # Directory from string to tuple pair of strings
+  # key: the filepath to a datafile including the datafile's basename. Example,
+  #   if the datafile was "/path/to/datafile.wav" then the key would be
+  #   "/path/to/datafile"
+  # value: a pair of strings (media_filepath, label)
+  data_files = dict()
+  for root, _, filenames in os.walk(directory):
+    transcripts = [filename for filename in filenames if transcription_ext in filename]
+    for transcript in transcripts:
+      basename = transcript.strip(transcription_ext)
+      transcript_path = os.path.join(root, transcript)
+      with open(transcript_path, 'r') as transcript_file:
+        for transcript_line in transcript_file:
+          line_contents = transcript_line.split(" ", 1)
+          assert len(line_contents) == 2
+          media_base, label = line_contents
+          key = os.path.join(root, media_base)
+          assert key not in data_files
+          media_name = "%s.%s" % (media_base, input_ext)
+          media_path = os.path.join(root, media_name)
+          data_files[key] = (media_path, label)
+  return data_files
+
+
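The walk above mirrors LibriSpeech's on-disk layout: each chapter directory holds one transcript file plus one FLAC file per utterance, and each transcript line pairs an utterance id with its uppercase text. A sketch of the mapping that results (paths illustrative; not part of the commit):

    # Extracted layout:
    #   LibriSpeech/dev-clean/84/121123/84-121123.trans.txt
    #   LibriSpeech/dev-clean/84/121123/84-121123-0000.flac
    # Transcript line: "84-121123-0000 GO DO YOU HEAR"
    data_files = _collect_data("LibriSpeech/dev-clean", "flac", "txt")
    # => {"LibriSpeech/dev-clean/84/121123/84-121123-0000":
    #         ("LibriSpeech/dev-clean/84/121123/84-121123-0000.flac",
    #          "GO DO YOU HEAR\n")}  # the label keeps the trailing newline
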
+def _get_audio_data(filepath):
+  # Construct a true .wav file.
+  out_filepath = filepath.strip(".flac") + ".wav"
+  # Assumes sox is installed on system. Sox converts from FLAC to WAV.
+  call(["sox", filepath, out_filepath])
+  wav_file = wave.open(open(out_filepath))
+  frame_count = wav_file.getnframes()
+  byte_array = wav_file.readframes(frame_count)
+
+  data = np.fromstring(byte_array, np.uint8).tolist()
+  return data, frame_count, wav_file.getsampwidth(), wav_file.getnchannels()
+
+
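LibriSpeech audio is 16 kHz, mono, 16-bit, so the helper returns two raw bytes per frame; it shells out to sox, which must be on PATH. A sanity-check sketch (file name illustrative; not part of the commit):

    data, frame_count, sample_width, num_channels = _get_audio_data(
        "84-121123-0000.flac")
    assert sample_width == 2   # 16-bit samples
    assert num_channels == 1   # mono
    assert len(data) == frame_count * sample_width * num_channels  # raw uint8 bytes
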
+class LibrispeechTextEncoder(text_encoder.TextEncoder):
+
+  def encode(self, s):
+    return [self._num_reserved_ids + ord(c) for c in s]
+
+  def decode(self, ids):
+    """Transform a sequence of int ids into a human-readable string.
+    EOS is not expected in ids.
+    Args:
+      ids: list of integers to be converted.
+    Returns:
+      s: human-readable string.
+    """
+    decoded_ids = []
+    for id_ in ids:
+      if 0 <= id_ < self._num_reserved_ids:
+        decoded_ids.append(text_encoder.RESERVED_TOKENS[int(id_)])
+      else:
+        decoded_ids.append(id_ - self._num_reserved_ids)
+    return "".join([chr(d) for d in decoded_ids])
+
+
+
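The encoder is a plain character shift with no vocabulary: encode offsets each ord(c) by the reserved-id count, and decode reverses it, mapping reserved ids through text_encoder.RESERVED_TOKENS. A round-trip sketch assuming the default two reserved ids, PAD and EOS (not part of the commit):

    encoder = LibrispeechTextEncoder()  # num_reserved_ids defaults to 2
    ids = encoder.encode("GO")          # ord("G") = 71, ord("O") = 79
    assert ids == [73, 81]              # each id shifted up by the 2 reserved ids
    assert encoder.decode(ids) == "GO"
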
+@registry.register_audio_modality
+class LibrispeechModality(modality.Modality):
+  """Performs strided conv compressions for audio spectral data."""
+
+  def bottom(self, inputs):
+    """Transform input from data space to model space.
+    Args:
+      inputs: A Tensor with shape [batch, ...]
+    Returns:
+      body_input: A Tensor with shape [batch, ?, ?, body_input_depth].
+    """
+    with tf.variable_scope(self.name):
+      # TODO(aidangomez): Will need to sort out a better audio pipeline
+      def xnet_resblock(x, filters, res_relu, name):
+        with tf.variable_scope(name):
+          # We only stride along the length dimension to preserve the spectral
+          # bins (which are tiny in dimensionality relative to length)
+          y = common_layers.separable_conv_block(
+              x,
+              filters, [((1, 1), (3, 3)), ((1, 1), (3, 3))],
+              first_relu=True,
+              padding="SAME",
+              force2d=True,
+              name="sep_conv_block")
+          y = common_layers.pool(y, (3, 3), "MAX", "SAME", strides=(2, 1))
+          return y + common_layers.conv_block(
+              x,
+              filters, [((1, 1), (1, 1))],
+              padding="SAME",
+              strides=(2, 1),
+              first_relu=res_relu,
+              force2d=True,
+              name="res_conv0")
+
+      # Rescale from UINT8 to floats in [-1, 1].
+      signals = (tf.to_float(inputs) - 127) / 128.
+      # signals = tf.contrib.framework.nest.flatten(signals)
+      signals = tf.squeeze(signals, [2, 3])
+
+      # `stfts` is a complex64 Tensor representing the Short-time Fourier Transform of
+      # each signal in `signals`. Its shape is [batch_size, ?, fft_unique_bins]
+      # where fft_unique_bins = fft_length // 2 + 1 = 513.
+      stfts = tf.contrib.signal.stft(signals, frame_length=1024, frame_step=512,
+                                     fft_length=1024)
+
+      # An energy spectrogram is the magnitude of the complex-valued STFT.
+      # A float32 Tensor of shape [batch_size, ?, 513].
+      magnitude_spectrograms = tf.abs(stfts)
+
+      log_offset = 1e-6
+      log_magnitude_spectrograms = tf.log(magnitude_spectrograms + log_offset)
+
+      # Warp the linear-scale, magnitude spectrograms into the mel-scale.
+      num_spectrogram_bins = magnitude_spectrograms.shape[-1].value
+      lower_edge_hertz, upper_edge_hertz, num_mel_bins = 80.0, 7600.0, 64
+      sample_rate = 16000
+      linear_to_mel_weight_matrix = tf.contrib.signal.linear_to_mel_weight_matrix(
+          num_mel_bins, num_spectrogram_bins, sample_rate, lower_edge_hertz,
+          upper_edge_hertz)
+      mel_spectrograms = tf.tensordot(
+          magnitude_spectrograms, linear_to_mel_weight_matrix, 1)
+      # Note: Shape inference for `tf.tensordot` does not currently handle this case.
+      mel_spectrograms.set_shape(magnitude_spectrograms.shape[:-1].concatenate(
+          linear_to_mel_weight_matrix.shape[-1:]))
+
+      # Try without the conversion to MFCCs, first.
+      # num_mfccs = 13
+      # Keep the first `num_mfccs` MFCCs.
+      # mfccs = tf.contrib.signal.mfccs_from_log_mel_spectrograms(
+      #     log_mel_spectrograms)[..., :num_mfccs]
+
+      x = tf.expand_dims(mel_spectrograms, 2)
+      x.set_shape([None, None, None, num_mel_bins])
+      for i in xrange(self._model_hparams.audio_compression):
+        x = xnet_resblock(x, 2**(i + 1), True, "compress_block_%d" % i)
+      return xnet_resblock(x, self._body_input_depth, False,
+                           "compress_block_final")
+
+
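The comments above pin down the shapes, and the frame arithmetic is easy to check by hand. A walkthrough for one second of 16 kHz audio (numbers follow from frame_length=1024, frame_step=512, num_mel_bins=64; a sketch, not part of the commit):

    samples = 16000
    frames = 1 + (samples - 1024) // 512   # 30 STFT frames (pad_end=False)
    fft_unique_bins = 1024 // 2 + 1        # 513, matching the comment above
    # stfts:                  [batch, 30, 513] complex64
    # magnitude_spectrograms: [batch, 30, 513] float32
    # mel_spectrograms:       [batch, 30, 64]
    # x after expand_dims:    [batch, 30, 1, 64]; each xnet_resblock then
    # halves the length dimension (strides=(2, 1)) while keeping the 64 bins.
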
+@registry.register_problem()
+class Librispeech(problem.Problem):
+  """Problem spec for LibriSpeech speech-to-text."""
+
+  @property
+  def is_character_level(self):
+    return True
+
+  @property
+  def input_space_id(self):
+    return problem.SpaceID.AUDIO_SPECTRAL
+
+  @property
+  def target_space_id(self):
+    return problem.SpaceID.EN_CHR
+
+  @property
+  def num_shards(self):
+    return 100
+
+  @property
+  def use_subword_tokenizer(self):
+    return False
+
+  @property
+  def num_dev_shards(self):
+    return 1
+
+  @property
+  def use_train_shards_for_dev(self):
+    """If true, we only generate training data and hold out shards for dev."""
+    return False
+
+  def feature_encoders(self, _):
+    return {
+        "inputs": text_encoder.TextEncoder(),
+        "targets": LibrispeechTextEncoder(),
+    }
+
+  def example_reading_spec(self):
+    data_fields = {
+        "inputs": tf.VarLenFeature(tf.int64),
+        # "audio/channel_count": tf.FixedLenFeature([], tf.int64),
+        # "audio/sample_count": tf.FixedLenFeature([], tf.int64),
+        # "audio/sample_width": tf.FixedLenFeature([], tf.int64),
+        "targets": tf.VarLenFeature(tf.int64),
+    }
+    data_items_to_decoders = None
+    return (data_fields, data_items_to_decoders)
+
+
+  def generator(self, data_dir, tmp_dir, training, eos_list=None, start_from=0, how_many=0):
+    eos_list = [1] if eos_list is None else eos_list
+    datasets = (_LIBRISPEECH_TRAIN_DATASETS if training else _LIBRISPEECH_TEST_DATASETS)
+    num_reserved_ids = self.feature_encoders(None)["targets"].num_reserved_ids
+    i = 0
+    for url, subdir in datasets:
+      filename = os.path.basename(url)
+      compressed_file = generator_utils.maybe_download(tmp_dir, filename, url)
+
+      read_type = "r:gz" if filename.endswith("tgz") else "r"
+      with tarfile.open(compressed_file, read_type) as corpus_tar:
+        # Create a subset of files that don't already exist:
+        # tarfile.extractall errors when encountering an existing file,
+        # and tarfile.extract is extremely slow.
+        members = []
+        for f in corpus_tar:
+          if not os.path.isfile(os.path.join(tmp_dir, f.name)):
+            members.append(f)
+        corpus_tar.extractall(tmp_dir, members=members)
+
+      data_dir = os.path.join(tmp_dir, "LibriSpeech", subdir)
+      data_files = _collect_data(data_dir, "flac", "txt")
+      data_pairs = data_files.values()
+      for media_file, text_data in sorted(data_pairs)[start_from:]:
+        if how_many > 0 and i == how_many:
+          return
+        i += 1
+        audio_data, sample_count, sample_width, num_channels = _get_audio_data(
+            media_file)
+        label = [num_reserved_ids + ord(c) for c in text_data] + eos_list
+        yield {
+            "inputs": audio_data,
+            "audio/channel_count": [num_channels],
+            "audio/sample_count": [sample_count],
+            "audio/sample_width": [sample_width],
+            "targets": label
+        }
+
+
+  def generate_data(self, data_dir, tmp_dir, task_id=-1):
+    train_paths = self.training_filepaths(data_dir, self.num_shards, shuffled=False)
+    dev_paths = self.dev_filepaths(data_dir, self.num_dev_shards, shuffled=False)
+    if self.use_train_shards_for_dev:
+      all_paths = train_paths + dev_paths
+      generator_utils.generate_files(self.generator(data_dir, tmp_dir, True), all_paths)
+      generator_utils.shuffle_dataset(all_paths)
+    else:
+      generator_utils.generate_dataset_and_shuffle(
+          self.generator(data_dir, tmp_dir, True), train_paths,
+          self.generator(data_dir, tmp_dir, False), dev_paths)
+
+
+  def hparams(self, defaults, unused_model_hparams):
+    p = defaults
+    p.stop_at_eos = int(False)
+    p.input_modality = {"inputs": ("audio:librispeech_modality", None)}
+    p.target_modality = (registry.Modalities.SYMBOL, 256)
+
+  def preprocess_example(self, example, mode, hparams):
+    return example
+
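Taken together, generate_data drives the whole pipeline: download the five openslr.org tarballs, extract, convert each FLAC through sox, and write sharded TFRecords. An end-to-end sketch (directories illustrative; not part of the commit):

    from tensor2tensor.utils import registry

    librispeech = registry.problem("librispeech")
    librispeech.generate_data("/tmp/t2t_data", "/tmp/t2t_tmp")
    # With the defaults above: 100 train shards from the train-* sets and
    # 1 dev shard from the dev-* sets, shuffled after generation.
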
+# TODO: clean up hparams
+@registry.register_hparams
+def librispeech_hparams():
+  hparams = transformer.transformer_base_single_gpu()  # Or whatever you'd like to build off.
+  hparams.batch_size = 36
+  hparams.audio_compression = 8
+  hparams.hidden_size = 2048
+  hparams.max_input_seq_length = 600000
+  hparams.max_target_seq_length = 350
+  hparams.max_length = hparams.max_input_seq_length
+  hparams.min_length_bucket = hparams.max_input_seq_length // 2
+  hparams.learning_rate = 0.05
+  hparams.train_steps = 5000000
+  hparams.num_hidden_layers = 4
+  return hparams
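
These hparams start from transformer_base_single_gpu and mostly widen the length limits to fit raw audio bytes. A quick consistency sketch (not part of the commit):

    hparams = librispeech_hparams()
    assert hparams.max_length == hparams.max_input_seq_length == 600000
    assert hparams.min_length_bucket == hparams.max_input_seq_length // 2
    # audio_compression=8 feeds the loop in LibrispeechModality.bottom,
    # adding eight stride-2 compress blocks before the final one.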
