Skip to content
This repository was archived by the owner on Jul 7, 2023. It is now read-only.

Commit c78abcd

Browse files
authored
Merge pull request #426 from martinpopel/decode-to-file
fix the semantics of decode_to_file
2 parents 8ca9709 + 985b637 commit c78abcd

File tree

2 files changed

+9
-13
lines changed

2 files changed

+9
-13
lines changed

tensor2tensor/bin/t2t-decoder

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -47,9 +47,8 @@ flags = tf.flags
4747
FLAGS = flags.FLAGS
4848

4949
flags.DEFINE_string("output_dir", "", "Training directory to load from.")
50-
flags.DEFINE_string("decode_from_file", None, "Path to decode file")
51-
flags.DEFINE_string("decode_to_file", None,
52-
"Path prefix to inference output file")
50+
flags.DEFINE_string("decode_from_file", None, "Path to the source file for decoding")
51+
flags.DEFINE_string("decode_to_file", None, "Path to the decoded (output) file")
5352
flags.DEFINE_bool("decode_interactive", False,
5453
"Interactive local inference mode.")
5554
flags.DEFINE_integer("decode_shards", 1, "Number of decoding replicas.")

tensor2tensor/utils/decoding.py

Lines changed: 7 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -252,17 +252,14 @@ def input_fn():
252252
# _decode_batch_input_fn
253253
sorted_inputs.reverse()
254254
decodes.reverse()
255-
# Dumping inputs and outputs to file filename.decodes in
256-
# format result\tinput in the same order as original inputs
257-
if decode_to_file:
258-
output_filename = decode_to_file
259-
else:
260-
output_filename = filename
255+
# If decode_to_file was provided use it as the output filename without any change
256+
# (except for adding shard_id if using more shards for decoding).
257+
# Otherwise, use the input filename plus model, hp, problem, beam, alpha.
258+
decode_filename = decode_to_file if decode_to_file else filename
261259
if decode_hp.shards > 1:
262-
base_filename = output_filename + ("%.2d" % decode_hp.shard_id)
263-
else:
264-
base_filename = output_filename
265-
decode_filename = _decode_filename(base_filename, problem_name, decode_hp)
260+
decode_filename = decode_filename + ("%.2d" % decode_hp.shard_id)
261+
if not decode_to_file:
262+
decode_filename = _decode_filename(decode_filename, problem_name, decode_hp)
266263
tf.logging.info("Writing decodes into %s" % decode_filename)
267264
outfile = tf.gfile.Open(decode_filename, "w")
268265
for index in range(len(sorted_inputs)):

0 commit comments

Comments
 (0)