Skip to content
This repository was archived by the owner on Jul 7, 2023. It is now read-only.

Commit 23129f2

Browse files
committed
Variety of fixes based on PR comments.
1 parent 98c7b41 commit 23129f2

File tree

1 file changed

+8
-7
lines changed

1 file changed

+8
-7
lines changed

tensor2tensor/data_generators/librispeech.py

Lines changed: 8 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -82,7 +82,7 @@ def _get_audio_data(filepath):
8282
class LibrispeechTextEncoder(text_encoder.TextEncoder):
8383

8484
def encode(self, s):
85-
return [ord[c] for c in s]
85+
return [self._num_reserved_ids + ord(c) for c in s]
8686

8787
def decode(self, ids):
8888
"""Transform a sequence of int ids into a human-readable string.
@@ -97,7 +97,7 @@ def decode(self, ids):
9797
if 0 <= id_ < self._num_reserved_ids:
9898
decoded_ids.append(RESERVED_TOKENS[int(id_)])
9999
else:
100-
decoded_ids.append(id_)
100+
decoded_ids.append(id_ - self._num_reserved_ids)
101101
return "".join([chr(d) for d in decoded_ids])
102102

103103

@@ -199,7 +199,7 @@ def target_space_id(self):
199199

200200
@property
201201
def num_shards(self):
202-
return 10
202+
return 100
203203

204204
@property
205205
def use_subword_tokenizer(self):
@@ -214,9 +214,9 @@ def use_train_shards_for_dev(self):
214214
"""If true, we only generate training data and hold out shards for dev."""
215215
return False
216216

217-
def feature_encoders(self, data_dir):
217+
def feature_encoders(self, _):
218218
return {
219-
"inputs": text_encoder.TextEncoder(), #None, #DoNothingEncoder(),
219+
"inputs": text_encoder.TextEncoder(),
220220
"targets": LibrispeechTextEncoder(),
221221
}
222222

@@ -233,8 +233,9 @@ def example_reading_spec(self):
233233

234234

235235
def generator(self, data_dir, tmp_dir, training, eos_list=None, start_from=0, how_many=0):
236-
eos_list = [1]
236+
eos_list = [1] if eos_list is None else eos_list
237237
datasets = (_LIBRISPEECH_TRAIN_DATASETS if training else _LIBRISPEECH_TEST_DATASETS)
238+
num_reserved_ids = self.feature_encoders(None)["targets"].num_reserved_ids
238239
i = 0
239240
for url, subdir in datasets:
240241
filename = os.path.basename(url)
@@ -260,7 +261,7 @@ def generator(self, data_dir, tmp_dir, training, eos_list=None, start_from=0, ho
260261
i += 1
261262
audio_data, sample_count, sample_width, num_channels = _get_audio_data(
262263
media_file)
263-
label = [ord(c) for c in text_data] + eos_list
264+
label = [num_reserved_ids + ord(c) for c in text_data] + eos_list
264265
yield {
265266
"inputs": audio_data,
266267
"audio/channel_count": [num_channels],

0 commit comments

Comments
 (0)