Skip to content
This repository was archived by the owner on Jul 7, 2023. It is now read-only.

Commit 545ec34

Browse files
author
Ryan Sepassi
committed
Add support for custom record delimiter in decoding
PiperOrigin-RevId: 172128016
1 parent 43dbf4c commit 545ec34

File tree

1 file changed

+19
-12
lines changed

1 file changed

+19
-12
lines changed

tensor2tensor/utils/decoding.py

Lines changed: 19 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,8 @@ def decode_hparams(overrides=""):
5252
return_beams=False,
5353
max_input_size=-1,
5454
identity_output=False,
55-
num_samples=-1)
55+
num_samples=-1,
56+
delimiter="\n")
5657
hp = hp.parse(overrides)
5758
return hp
5859

@@ -176,8 +177,8 @@ def decode_from_dataset(estimator,
176177
# Write out predictions if decode_to_file passed
177178
if decode_to_file:
178179
for decoded_output, decoded_target in decoded_outputs:
179-
output_file.write(str(decoded_output) + "\n")
180-
target_file.write(str(decoded_target) + "\n")
180+
output_file.write(str(decoded_output) + decode_hp.delimiter)
181+
target_file.write(str(decoded_target) + decode_hp.delimiter)
181182

182183
if (decode_hp.num_samples >= 0 and
183184
num_predictions >= decode_hp.num_samples):
@@ -203,7 +204,8 @@ def decode_from_file(estimator, filename, decode_hp, decode_to_file=None):
203204
targets_vocab = hparams.problems[problem_id].vocabulary["targets"]
204205
problem_name = FLAGS.problems.split("-")[problem_id]
205206
tf.logging.info("Performing decoding from a file.")
206-
sorted_inputs, sorted_keys = _get_sorted_inputs(filename, decode_hp.shards)
207+
sorted_inputs, sorted_keys = _get_sorted_inputs(filename, decode_hp.shards,
208+
decode_hp.delimiter)
207209
num_decode_batches = (len(sorted_inputs) - 1) // decode_hp.batch_size + 1
208210

209211
def input_fn():
@@ -251,7 +253,7 @@ def input_fn():
251253
tf.logging.info("Writing decodes into %s" % decode_filename)
252254
outfile = tf.gfile.Open(decode_filename, "w")
253255
for index in range(len(sorted_inputs)):
254-
outfile.write("%s\n" % (decodes[sorted_keys[index]]))
256+
outfile.write("%s%s" % (decodes[sorted_keys[index]], decode_hp.delimiter))
255257

256258

257259
def _decode_filename(base_filename, problem_name, decode_hp):
@@ -472,13 +474,14 @@ def show_and_save_image(img, save_path):
472474
plt.savefig(save_path)
473475

474476

475-
def _get_sorted_inputs(filename, num_shards=1):
477+
def _get_sorted_inputs(filename, num_shards=1, delimiter="\n"):
476478
"""Returning inputs sorted according to length.
477479
478480
Args:
479481
filename: path to file with inputs, 1 per line.
480482
num_shards: number of input shards. If > 1, will read from file filename.XX,
481483
where XX is FLAGS.worker_id.
484+
delimiter: str, delimits records in the file.
482485
483486
Returns:
484487
a sorted list of inputs
@@ -490,8 +493,12 @@ def _get_sorted_inputs(filename, num_shards=1):
490493
decode_filename = filename + ("%.2d" % FLAGS.worker_id)
491494
else:
492495
decode_filename = filename
493-
inputs = [line.strip() for line in tf.gfile.Open(decode_filename)]
494-
input_lens = [(i, len(line.strip().split())) for i, line in enumerate(inputs)]
496+
497+
with tf.gfile.Open(decode_filename) as f:
498+
text = f.read()
499+
records = text.split(delimiter)
500+
inputs = [record.strip() for record in records]
501+
input_lens = [(i, len(line.split())) for i, line in enumerate(inputs)]
495502
sorted_input_lens = sorted(input_lens, key=operator.itemgetter(1))
496503
# We'll need the keys to rearrange the inputs back into their original order
497504
sorted_keys = {}
@@ -553,8 +560,8 @@ def input_fn(problem_choice, x=inputs): # pylint: disable=missing-docstring
553560
feature_map["problem_choice"])
554561
features["input_space_id"] = input_space_id
555562
features["target_space_id"] = target_space_id
556-
features["decode_length"] = (IMAGE_DECODE_LENGTH
557-
if input_is_image else inputs[1])
563+
features["decode_length"] = (
564+
IMAGE_DECODE_LENGTH if input_is_image else inputs[1])
558565
features["inputs"] = x
559566
return features
560567

@@ -588,7 +595,7 @@ def input_fn(problem_choice, x=inputs): # pylint: disable=missing-docstring
588595
features["problem_choice"] = feature_map["problem_choice"]
589596
features["input_space_id"] = input_space_id
590597
features["target_space_id"] = target_space_id
591-
features["decode_length"] = (IMAGE_DECODE_LENGTH
592-
if input_is_image else tf.shape(x)[1] + 50)
598+
features["decode_length"] = (
599+
IMAGE_DECODE_LENGTH if input_is_image else tf.shape(x)[1] + 50)
593600
features["inputs"] = x
594601
return features

0 commit comments

Comments
 (0)