1616"""Librispeech dataset."""
1717
1818import os
19- from subprocess import call
2019import tarfile
21- import wave
2220
2321# Dependency imports
2422
25- import numpy as np
26-
2723from tensor2tensor .data_generators import generator_utils
28- from tensor2tensor .data_generators import problem
29- from tensor2tensor .data_generators import text_encoder
30- from tensor2tensor .layers import common_layers
31- from tensor2tensor .utils import modality
24+ from tensor2tensor .data_generators import speech_recognition
3225from tensor2tensor .utils import registry
3326
34- import tensorflow as tf
35-
3627
3728_LIBRISPEECH_TRAIN_DATASETS = [
3829 [
@@ -86,130 +77,13 @@ def _collect_data(directory, input_ext, transcription_ext):
8677 return data_files
8778
8879
89- def _get_audio_data (filepath ):
90- # Construct a true .wav file.
91- out_filepath = filepath .strip (".flac" ) + ".wav"
92- # Assumes sox is installed on system. Sox converts from FLAC to WAV.
93- call (["sox" , filepath , out_filepath ])
94- wav_file = wave .open (open (out_filepath ))
95- frame_count = wav_file .getnframes ()
96- byte_array = wav_file .readframes (frame_count )
97-
98- data = np .fromstring (byte_array , np .uint8 ).tolist ()
99- return data , frame_count , wav_file .getsampwidth (), wav_file .getnchannels ()
100-
101-
102- class LibrispeechTextEncoder (text_encoder .TextEncoder ):
103-
104- def encode (self , s ):
105- return [self ._num_reserved_ids + ord (c ) for c in s ]
106-
107- def decode (self , ids ):
108- """Transform a sequence of int ids into a human-readable string.
109-
110- EOS is not expected in ids.
111-
112- Args:
113- ids: list of integers to be converted.
114- Returns:
115- s: human-readable string.
116- """
117- decoded_ids = []
118- for id_ in ids :
119- if 0 <= id_ < self ._num_reserved_ids :
120- decoded_ids .append (text_encoder .RESERVED_TOKENS [int (id_ )])
121- else :
122- decoded_ids .append (id_ - self ._num_reserved_ids )
123- return "" .join ([chr (d ) for d in decoded_ids ])
124-
125-
126- @registry .register_audio_modality
127- class LibrispeechModality (modality .Modality ):
128- """Performs strided conv compressions for audio spectral data."""
129-
130- def bottom (self , inputs ):
131- """Transform input from data space to model space.
132-
133- Args:
134- inputs: A Tensor with shape [batch, ...]
135- Returns:
136- body_input: A Tensor with shape [batch, ?, ?, body_input_depth].
137- """
138- with tf .variable_scope (self .name ):
139- # TODO(aidangomez): Will need to sort out a better audio pipeline
140- def xnet_resblock (x , filters , res_relu , name ):
141- with tf .variable_scope (name ):
142- # We only stride along the length dimension to preserve the spectral
143- # bins (which are tiny in dimensionality relative to length)
144- y = common_layers .separable_conv_block (
145- x ,
146- filters , [((1 , 1 ), (3 , 3 )), ((1 , 1 ), (3 , 3 ))],
147- first_relu = True ,
148- padding = "SAME" ,
149- force2d = True ,
150- name = "sep_conv_block" )
151- y = common_layers .pool (y , (3 , 3 ), "MAX" , "SAME" , strides = (2 , 1 ))
152- return y + common_layers .conv_block (
153- x ,
154- filters , [((1 , 1 ), (1 , 1 ))],
155- padding = "SAME" ,
156- strides = (2 , 1 ),
157- first_relu = res_relu ,
158- force2d = True ,
159- name = "res_conv0" )
160-
161- # Rescale from UINT8 to floats in [-1,-1]
162- signals = (tf .to_float (inputs )- 127 )/ 128.
163- signals = tf .squeeze (signals , [2 , 3 ])
164-
165- # `stfts` is a complex64 Tensor representing the short-time Fourier
166- # Transform of each signal in `signals`. Its shape is
167- # [batch_size, ?, fft_unique_bins]
168- # where fft_unique_bins = fft_length // 2 + 1 = 513.
169- stfts = tf .contrib .signal .stft (signals , frame_length = 1024 , frame_step = 512 ,
170- fft_length = 1024 )
171-
172- # An energy spectrogram is the magnitude of the complex-valued STFT.
173- # A float32 Tensor of shape [batch_size, ?, 513].
174- magnitude_spectrograms = tf .abs (stfts )
175-
176- # Warp the linear-scale, magnitude spectrograms into the mel-scale.
177- num_spectrogram_bins = magnitude_spectrograms .shape [- 1 ].value
178- lower_edge_hertz , upper_edge_hertz , num_mel_bins = 80.0 , 7600.0 , 64
179- sample_rate = 16000
180- linear_to_mel_weight_matrix = (
181- tf .contrib .signal .linear_to_mel_weight_matrix (
182- num_mel_bins , num_spectrogram_bins , sample_rate , lower_edge_hertz ,
183- upper_edge_hertz ))
184- mel_spectrograms = tf .tensordot (
185- magnitude_spectrograms , linear_to_mel_weight_matrix , 1 )
186- # Note: Shape inference for tensordot does not currently handle this case.
187- mel_spectrograms .set_shape (magnitude_spectrograms .shape [:- 1 ].concatenate (
188- linear_to_mel_weight_matrix .shape [- 1 :]))
189-
190- x = tf .expand_dims (mel_spectrograms , 2 )
191- x .set_shape ([None , None , None , num_mel_bins ])
192- for i in xrange (self ._model_hparams .audio_compression ):
193- x = xnet_resblock (x , 2 ** (i + 1 ), True , "compress_block_%d" % i )
194- return xnet_resblock (x , self ._body_input_depth , False ,
195- "compress_block_final" )
196-
197-
19880@registry .register_problem ()
199- class Librispeech (problem . Problem ):
200- """Problem spec for English word to dictionary definition ."""
81+ class Librispeech (speech_recognition . SpeechRecognitionProblem ):
82+ """Problem spec for Librispeech using clean and noisy data ."""
20183
202- @property
203- def is_character_level (self ):
204- return True
205-
206- @property
207- def input_space_id (self ):
208- return problem .SpaceID .AUDIO_SPECTRAL
209-
210- @property
211- def target_space_id (self ):
212- return problem .SpaceID .EN_CHR
84+ # Select all of the data (both clean and noisy); subclasses narrow this.
85+ TRAIN_DATASETS = _LIBRISPEECH_TRAIN_DATASETS
86+ DEV_DATASETS = _LIBRISPEECH_TEST_DATASETS
21387
21488 @property
21589 def num_shards (self ):
@@ -228,26 +102,8 @@ def use_train_shards_for_dev(self):
228102 """If true, we only generate training data and hold out shards for dev."""
229103 return False
230104
231- def feature_encoders (self , _ ):
232- return {
233- "inputs" : text_encoder .TextEncoder (),
234- "targets" : LibrispeechTextEncoder (),
235- }
236-
237- def example_reading_spec (self ):
238- data_fields = {
239- "inputs" : tf .VarLenFeature (tf .int64 ),
240- "targets" : tf .VarLenFeature (tf .int64 ),
241- }
242- data_items_to_decoders = None
243- return (data_fields , data_items_to_decoders )
244-
245- def generator (self , data_dir , tmp_dir , training ,
105+ def generator (self , data_dir , tmp_dir , datasets ,
246106 eos_list = None , start_from = 0 , how_many = 0 ):
247- eos_list = [1 ] if eos_list is None else eos_list
248- datasets = (_LIBRISPEECH_TRAIN_DATASETS if training
249- else _LIBRISPEECH_TEST_DATASETS )
250- num_reserved_ids = self .feature_encoders (None )["targets" ].num_reserved_ids
251107 i = 0
252108 for url , subdir in datasets :
253109 filename = os .path .basename (url )
@@ -267,44 +123,53 @@ def generator(self, data_dir, tmp_dir, training,
267123 data_dir = os .path .join (tmp_dir , "LibriSpeech" , subdir )
268124 data_files = _collect_data (data_dir , "flac" , "txt" )
269125 data_pairs = data_files .values ()
126+
127+ encoders = self .feature_encoders (None )
128+ audio_encoder = encoders ["waveforms" ]
129+ text_encoder = encoders ["targets" ]
130+
270131 for media_file , text_data in sorted (data_pairs )[start_from :]:
271132 if how_many > 0 and i == how_many :
272133 return
273134 i += 1
274- audio_data , sample_count , sample_width , num_channels = _get_audio_data (
275- media_file )
276- label = [num_reserved_ids + ord (c ) for c in text_data ] + eos_list
277135 yield {
278- "inputs" : audio_data ,
279- "audio/channel_count" : [num_channels ],
280- "audio/sample_count" : [sample_count ],
281- "audio/sample_width" : [sample_width ],
282- "targets" : label
136+ "waveforms" : audio_encoder .encode (media_file ),
137+ "targets" : text_encoder .encode (text_data )
283138 }
284139
285140 def generate_data (self , data_dir , tmp_dir , task_id = - 1 ):
286141 train_paths = self .training_filepaths (
287142 data_dir , self .num_shards , shuffled = False )
288143 dev_paths = self .dev_filepaths (
289144 data_dir , self .num_dev_shards , shuffled = False )
145+
290146 if self .use_train_shards_for_dev :
291147 all_paths = train_paths + dev_paths
292148 generator_utils .generate_files (
293- self .generator (data_dir , tmp_dir , True ), all_paths )
149+ self .generator (data_dir , tmp_dir , self . TRAIN_DATASETS ), all_paths )
294150 generator_utils .shuffle_dataset (all_paths )
295151 else :
296152 generator_utils .generate_dataset_and_shuffle (
297- self .generator (data_dir , tmp_dir , True ), train_paths ,
298- self .generator (data_dir , tmp_dir , False ), dev_paths )
153+ self .generator (data_dir , tmp_dir , self . TRAIN_DATASETS ), train_paths ,
154+ self .generator (data_dir , tmp_dir , self . DEV_DATASETS ), dev_paths )
299155
300- def hparams (self , defaults , unused_model_hparams ):
301- p = defaults
302- p .stop_at_eos = int (False )
303- p .input_modality = {"inputs" : ("audio:librispeech_modality" , None )}
304- p .target_modality = (registry .Modalities .SYMBOL , 256 )
305156
306- def preprocess_example (self , example , mode , hparams ):
307- return example
157+ @registry .register_problem ()
158+ class LibrispeechCleanSmall (Librispeech ):
159+ """Problem spec for Librispeech using 100h clean train data."""
160+
161+ # Select only the clean data
162+ TRAIN_DATASETS = _LIBRISPEECH_TRAIN_DATASETS [:1 ]
163+ DEV_DATASETS = _LIBRISPEECH_TEST_DATASETS [:1 ]
164+
165+
166+ @registry .register_problem ()
167+ class LibrispeechClean (Librispeech ):
168+ """Problem spec for Librispeech using 460h clean train data."""
169+
170+ # Select only the clean data
171+ TRAIN_DATASETS = _LIBRISPEECH_TRAIN_DATASETS [:2 ]
172+ DEV_DATASETS = _LIBRISPEECH_TEST_DATASETS [:1 ]
308173
309174
310175# TODO(lukaszkaiser): clean up hparams or remove from here.
0 commit comments