This repository was archived by the owner on Jul 7, 2023. It is now read-only.

Commit 30887b8

Target string displayed; smaller fixes
1 parent 80e8a55 commit 30887b8

3 files changed, +47 -56 lines

tensor2tensor/data_generators/text_encoder.py

Lines changed: 3 additions & 3 deletions
@@ -37,21 +37,21 @@
 
 # Conversion between Unicode and UTF-8, if required (on Python2)
 native_to_unicode = (lambda s: s.decode("utf-8")) if PY2 else (lambda s: s)
-
-
 unicode_to_native = (lambda s: s.encode("utf-8")) if PY2 else (lambda s: s)
 
 
 # Reserved tokens for things like padding and EOS symbols.
 PAD = "<pad>"
 EOS = "<EOS>"
 RESERVED_TOKENS = [PAD, EOS]
+PAD_TOKEN = RESERVED_TOKENS.index(PAD)  # Normally 0
+EOS_TOKEN = RESERVED_TOKENS.index(EOS)  # Normally 1
+
 if six.PY2:
   RESERVED_TOKENS_BYTES = RESERVED_TOKENS
 else:
   RESERVED_TOKENS_BYTES = [bytes(PAD, "ascii"), bytes(EOS, "ascii")]
 
-
 class TextEncoder(object):
   """Base class for converting from ints to/from human readable strings."""
 
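For reference, a minimal sketch of how the new constants resolve (an illustrative snippet, not part of the commit):

from tensor2tensor.data_generators import text_encoder

# Deriving the ids via index() keeps them in sync with RESERVED_TOKENS
# instead of hard-coding the magic numbers 0 and 1 at each call site.
print(text_encoder.PAD_TOKEN)  # 0
print(text_encoder.EOS_TOKEN)  # 1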

tensor2tensor/models/transformer.py

Lines changed: 1 addition & 0 deletions
@@ -359,6 +359,7 @@ def transformer_parsing_ice():
   hparams = transformer_parsing_base()
   hparams.batch_size = 4096
   hparams.batching_mantissa_bits = 2
+  hparams.hidden_size = 512
   #hparams.max_length = 256
   #hparams.hidden_size = 128
   return hparams
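As a quick sanity check, the resulting hparams can be inspected (an illustrative sketch, not part of the commit):

from tensor2tensor.models import transformer

hparams = transformer.transformer_parsing_ice()
print(hparams.hidden_size)  # 512, now set explicitly rather than coming from the base hparams
print(hparams.batch_size)   # 4096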

tensor2tensor/utils/trainer_utils.py

File mode changed: 100644 → 100755
Lines changed: 43 additions & 53 deletions
@@ -30,16 +30,15 @@
 # pylint: disable=redefined-builtin
 from six.moves import input
 from six.moves import xrange
-from six.moves import zip
 # pylint: enable=redefined-builtin
 
 from tensor2tensor.data_generators import problem_hparams
+from tensor2tensor.data_generators.text_encoder import EOS_TOKEN
 from tensor2tensor.models import models  # pylint: disable=unused-import
 from tensor2tensor.utils import data_reader
 from tensor2tensor.utils import expert_utils as eu
 from tensor2tensor.utils import metrics
 from tensor2tensor.utils import registry
-
 import tensorflow as tf
 from tensorflow.contrib.learn.python.learn import learn_runner
 from tensorflow.python.ops import init_ops
@@ -120,6 +119,16 @@
                   "<beam1>\t<beam2>..\t<input>")
 
 
+def _save_until_eos(hyp):
+  """ Strips everything after the first <EOS> token, which is normally 1 """
+  try:
+    index = list(hyp).index(EOS_TOKEN)
+    return hyp[0:index]
+  except ValueError:
+    # No EOS_TOKEN: return the array as-is
+    return hyp
+
+
 def make_experiment_fn(data_dir, model_name, train_steps, eval_steps):
   """Returns experiment_fn for learn_runner. Wraps create_experiment."""
 
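The new module-level helper truncates a decoded id sequence at the first EOS. A small illustration of its behavior (the arrays are made up for the example):

import numpy as np

print(_save_until_eos(np.array([17, 5, 1, 9, 9])))  # [17 5] -- stops at EOS_TOKEN == 1
print(_save_until_eos(np.array([17, 5, 9])))        # [17 5 9] -- no EOS, returned as-is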
@@ -279,7 +288,6 @@ def session_config():
       rewrite_options=rewrite_options, infer_shapes=True)
   config = tf.ConfigProto(
       allow_soft_placement=True, graph_options=graph_options)
-
   return config
 
 
@@ -345,6 +353,7 @@ def learning_rate_decay():
       lambda: decay,
       name="learning_rate_decay_warump_cond")
 
+
 def model_fn(features, targets, mode):
   """Creates the prediction, loss, and train ops.
 
@@ -356,10 +365,11 @@ def model_fn(features, targets, mode):
   Returns:
     A tuple consisting of the prediction, loss, and train_op.
   """
-  if mode == tf.contrib.learn.ModeKeys.INFER and FLAGS.decode_interactive:
-    features = _interactive_input_tensor_to_features_dict(features, hparams)
-  if mode == tf.contrib.learn.ModeKeys.INFER and FLAGS.decode_from_file:
-    features = _decode_input_tensor_to_features_dict(features, hparams)
+  if mode == tf.contrib.learn.ModeKeys.INFER:
+    if FLAGS.decode_interactive:
+      features = _interactive_input_tensor_to_features_dict(features, hparams)
+    elif FLAGS.decode_from_file:
+      features = _decode_input_tensor_to_features_dict(features, hparams)
   # A dictionary containing:
   # - problem_choice: A Tensor containing an integer indicating which problem
   #   was selected for this run.
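Nesting both checks under a single INFER test also makes the two decode modes mutually exclusive: with both flags set, the old code would have applied both feature-dict transformations in turn, while the elif now gives interactive decoding precedence.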
@@ -579,12 +589,14 @@ def log_fn(inputs,
                                "%s_prediction_%d.jpg" % (problem, j))
       show_and_save_image(inputs / 255., save_path)
     elif inputs_vocab:
-      decoded_inputs = inputs_vocab.decode(inputs.flatten())
+      decoded_inputs = inputs_vocab.decode(_save_until_eos(inputs.flatten()))
       tf.logging.info("Inference results INPUT: %s" % decoded_inputs)
 
-    decoded_outputs = targets_vocab.decode(outputs.flatten())
-    decoded_targets = targets_vocab.decode(targets.flatten())
+    decoded_outputs = targets_vocab.decode(_save_until_eos(outputs.flatten()))
     tf.logging.info("Inference results OUTPUT: %s" % decoded_outputs)
+    decoded_targets = targets_vocab.decode(_save_until_eos(targets.flatten()))
+    tf.logging.info("Inference results TARGET: %s" % decoded_targets)
+
     if FLAGS.decode_to_file:
       output_filepath = FLAGS.decode_to_file + ".outputs." + problem
       output_file = tf.gfile.Open(output_filepath, "a")
@@ -599,27 +611,16 @@ def log_fn(inputs,
   # iterator to log inputs and decodes.
   if FLAGS.decode_endless:
     tf.logging.info("Warning: Decoding endlessly")
-    for j, result in enumerate(result_iter):
-      inputs, targets, outputs = (result["inputs"], result["targets"],
-                                  result["outputs"])
-      if FLAGS.decode_return_beams:
-        output_beams = np.split(outputs, FLAGS.decode_beam_size, axis=0)
-        for k, beam in enumerate(output_beams):
-          tf.logging.info("BEAM %d:" % k)
-          log_fn(inputs, targets, beam, problem, j)
-      else:
-        log_fn(inputs, targets, outputs, problem, j)
-  else:
-    for j, (inputs, targets, outputs) in enumerate(
-        zip(result_iter["inputs"], result_iter["targets"], result_iter[
-            "outputs"])):
-      if FLAGS.decode_return_beams:
-        output_beams = np.split(outputs, FLAGS.decode_beam_size, axis=0)
-        for k, beam in enumerate(output_beams):
-          tf.logging.info("BEAM %d:" % k)
-          log_fn(inputs, targets, beam, problem, j)
-      else:
-        log_fn(inputs, targets, outputs, problem, j)
+  for j, result in enumerate(result_iter):
+    inputs, targets, outputs = (result["inputs"], result["targets"],
+                                result["outputs"])
+    if FLAGS.decode_return_beams:
+      output_beams = np.split(outputs, FLAGS.decode_beam_size, axis=0)
+      for k, beam in enumerate(output_beams):
+        tf.logging.info("BEAM %d:" % k)
+        log_fn(inputs, targets, beam, problem, j)
+    else:
+      log_fn(inputs, targets, outputs, problem, j)
 
 
 def decode_from_file(estimator, filename):
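For context, a toy sketch of the beam splitting in the now-unified loop (sizes are illustrative; FLAGS.decode_beam_size is assumed to be 4):

import numpy as np

outputs = np.arange(12)                      # stand-in for 4 concatenated beams
output_beams = np.split(outputs, 4, axis=0)  # 4 arrays of length 3, one per beam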
@@ -628,22 +629,12 @@ def decode_from_file(estimator, filename):
   problem_id = FLAGS.decode_problem_id
   inputs_vocab = hparams.problems[problem_id].vocabulary["inputs"]
   targets_vocab = hparams.problems[problem_id].vocabulary["targets"]
-  tf.logging.info("Performing Decoding from a file.")
+  tf.logging.info("Performing decoding from a file.")
   sorted_inputs, sorted_keys = _get_sorted_inputs(filename)
   num_decode_batches = (len(sorted_inputs) - 1) // FLAGS.decode_batch_size + 1
   input_fn = _decode_batch_input_fn(problem_id, num_decode_batches,
                                     sorted_inputs, inputs_vocab)
 
-  # strips everything after the first <EOS> id, which is assumed to be 1
-  def _save_until_eos(hyp):  # pylint: disable=missing-docstring
-    ret = []
-    index = 0
-    # until you reach <EOS> id
-    while index < len(hyp) and hyp[index] != 1:
-      ret.append(hyp[index])
-      index += 1
-    return np.array(ret)
-
   decodes = []
   for _ in range(num_decode_batches):
     result_iter = estimator.predict(
@@ -655,8 +646,7 @@ def log_fn(inputs, outputs):
     decoded_inputs = inputs_vocab.decode(_save_until_eos(inputs.flatten()))
     tf.logging.info("Inference results INPUT: %s" % decoded_inputs)
 
-    decoded_outputs = targets_vocab.decode(
-        _save_until_eos(outputs.flatten()))
+    decoded_outputs = targets_vocab.decode(_save_until_eos(outputs.flatten()))
    tf.logging.info("Inference results OUTPUT: %s" % decoded_outputs)
     return decoded_outputs
 
@@ -667,7 +657,7 @@ def log_fn(inputs, outputs):
       for k, beam in enumerate(output_beams):
         tf.logging.info("BEAM %d:" % k)
         beam_decodes.append(log_fn(result["inputs"], beam))
-      decodes.append(str.join("\t", beam_decodes))
+      decodes.append("\t".join(beam_decodes))
 
     else:
       decodes.append(log_fn(result["inputs"], result["outputs"]))
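The two spellings are equivalent; the bound-method form is simply the idiomatic one:

beam_decodes = ["first beam", "second beam"]
assert str.join("\t", beam_decodes) == "\t".join(beam_decodes)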
@@ -709,11 +699,11 @@ def decode_interactively(estimator):
         tf.logging.info("BEAM %d:" % k)
         if scores is not None:
           tf.logging.info("%s\tScore:%f" %
-                          (targets_vocab.decode(beam.flatten()), scores[k]))
+                          (targets_vocab.decode(_save_until_eos(beam.flatten())), scores[k]))
         else:
-          tf.logging.info(targets_vocab.decode(beam.flatten()))
+          tf.logging.info(targets_vocab.decode(_save_until_eos(beam.flatten())))
     else:
-      tf.logging.info(targets_vocab.decode(result["outputs"].flatten()))
+      tf.logging.info(targets_vocab.decode(_save_until_eos(result["outputs"].flatten())))
 
 
 def _decode_batch_input_fn(problem_id, num_decode_batches, sorted_inputs,
@@ -726,10 +716,10 @@ def _decode_batch_input_fn(problem_id, num_decode_batches, sorted_inputs,
     tf.logging.info("Deocding batch %d" % b)
     batch_length = 0
     batch_inputs = []
-    for inputs in sorted_inputs[b * FLAGS.decode_batch_size:(
-        b + 1) * FLAGS.decode_batch_size]:
+    for inputs in sorted_inputs[b * FLAGS.decode_batch_size:
+                                (b + 1) * FLAGS.decode_batch_size]:
       input_ids = vocabulary.encode(inputs)
-      input_ids.append(1)  # Assuming EOS=1.
+      input_ids.append(EOS_TOKEN)
       batch_inputs.append(input_ids)
       if len(input_ids) > batch_length:
         batch_length = len(input_ids)
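A toy illustration of the batch slicing above (decode_batch_size stands in for FLAGS.decode_batch_size and is assumed to be 2):

sorted_inputs = ["a", "b", "c", "d", "e"]
decode_batch_size = 2
num_decode_batches = (len(sorted_inputs) - 1) // decode_batch_size + 1
for b in range(num_decode_batches):
  print(sorted_inputs[b * decode_batch_size:(b + 1) * decode_batch_size])
# ['a', 'b'], then ['c', 'd'], then ['e']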
@@ -822,7 +812,7 @@ def _interactive_input_fn(hparams):
     if input_type == "text":
       input_ids = vocabulary.encode(input_string)
       if has_input:
-        input_ids.append(1)  # assume 1 means end-of-source
+        input_ids.append(EOS_TOKEN)
       x = [num_samples, decode_length, len(input_ids)] + input_ids
       assert len(x) < const_array_size
       x += [0] * (const_array_size - len(x))
@@ -1089,7 +1079,7 @@ def input_fn():
       problem_choice = tf.to_int32(FLAGS.worker_id % problem_count)
     else:
       raise ValueError("Value of hparams.problem_choice is %s and must be "
-                       "one of [uniform, adaptive, distributed]",
+                       "one of [uniform, adaptive, distributed]" %
                        hparams.problem_choice)
 
     # Inputs and targets conditional on problem_choice.
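The trailing % matters here: unlike tf.logging, the ValueError constructor does not perform printf-style interpolation, so the old comma form produced a message with the literal %s left in place. A minimal illustration:

print(ValueError("choice is %s", "bad"))   # ('choice is %s', 'bad') -- never interpolated
print(ValueError("choice is %s" % "bad"))  # choice is bad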
