3030# pylint: disable=redefined-builtin
3131from six .moves import input
3232from six .moves import xrange
33- from six .moves import zip
3433# pylint: enable=redefined-builtin
3534
3635from tensor2tensor .data_generators import problem_hparams
36+ from tensor2tensor .data_generators .text_encoder import EOS_TOKEN
3737from tensor2tensor .models import models # pylint: disable=unused-import
3838from tensor2tensor .utils import data_reader
3939from tensor2tensor .utils import expert_utils as eu
4040from tensor2tensor .utils import metrics
4141from tensor2tensor .utils import registry
42-
4342import tensorflow as tf
4443from tensorflow .contrib .learn .python .learn import learn_runner
4544from tensorflow .python .ops import init_ops
120119 "<beam1>\t <beam2>..\t <input>" )
121120
122121
def _save_until_eos(hyp):
  """Truncates `hyp` just before its first EOS_TOKEN.

  Args:
    hyp: a sequence of token ids. It must be convertible with `list()` and
      support slicing.

  Returns:
    The leading part of `hyp` up to, but not including, the first element
    equal to EOS_TOKEN. If no element equals EOS_TOKEN, `hyp` is returned
    unchanged.
  """
  try:
    # list.index raises ValueError when EOS_TOKEN is absent.
    eos_position = list(hyp).index(EOS_TOKEN)
  except ValueError:
    # No EOS marker present: hand the sequence back untouched.
    return hyp
  return hyp[:eos_position]
130+
131+
123132def make_experiment_fn (data_dir , model_name , train_steps , eval_steps ):
124133 """Returns experiment_fn for learn_runner. Wraps create_experiment."""
125134
@@ -279,7 +288,6 @@ def session_config():
279288 rewrite_options = rewrite_options , infer_shapes = True )
280289 config = tf .ConfigProto (
281290 allow_soft_placement = True , graph_options = graph_options )
282-
283291 return config
284292
285293
@@ -345,6 +353,7 @@ def learning_rate_decay():
345353 lambda : decay ,
346354 name = "learning_rate_decay_warump_cond" )
347355
356+
348357 def model_fn (features , targets , mode ):
349358 """Creates the prediction, loss, and train ops.
350359
@@ -356,10 +365,11 @@ def model_fn(features, targets, mode):
356365 Returns:
357366 A tuple consisting of the prediction, loss, and train_op.
358367 """
359- if mode == tf .contrib .learn .ModeKeys .INFER and FLAGS .decode_interactive :
360- features = _interactive_input_tensor_to_features_dict (features , hparams )
361- if mode == tf .contrib .learn .ModeKeys .INFER and FLAGS .decode_from_file :
362- features = _decode_input_tensor_to_features_dict (features , hparams )
368+ if mode == tf .contrib .learn .ModeKeys .INFER :
369+ if FLAGS .decode_interactive :
370+ features = _interactive_input_tensor_to_features_dict (features , hparams )
371+ elif FLAGS .decode_from_file :
372+ features = _decode_input_tensor_to_features_dict (features , hparams )
363373 # A dictionary containing:
364374 # - problem_choice: A Tensor containing an integer indicating which problem
365375 # was selected for this run.
@@ -579,12 +589,14 @@ def log_fn(inputs,
579589 "%s_prediction_%d.jpg" % (problem , j ))
580590 show_and_save_image (inputs / 255. , save_path )
581591 elif inputs_vocab :
582- decoded_inputs = inputs_vocab .decode (inputs .flatten ())
592+ decoded_inputs = inputs_vocab .decode (_save_until_eos ( inputs .flatten () ))
583593 tf .logging .info ("Inference results INPUT: %s" % decoded_inputs )
584594
585- decoded_outputs = targets_vocab .decode (outputs .flatten ())
586- decoded_targets = targets_vocab .decode (targets .flatten ())
595+ decoded_outputs = targets_vocab .decode (_save_until_eos (outputs .flatten ()))
587596 tf .logging .info ("Inference results OUTPUT: %s" % decoded_outputs )
597+ decoded_targets = targets_vocab .decode (_save_until_eos (targets .flatten ()))
598+ tf .logging .info ("Inference results TARGET: %s" % decoded_targets )
599+
588600 if FLAGS .decode_to_file :
589601 output_filepath = FLAGS .decode_to_file + ".outputs." + problem
590602 output_file = tf .gfile .Open (output_filepath , "a" )
@@ -599,27 +611,16 @@ def log_fn(inputs,
599611 # iterator to log inputs and decodes.
600612 if FLAGS .decode_endless :
601613 tf .logging .info ("Warning: Decoding endlessly" )
602- for j , result in enumerate (result_iter ):
603- inputs , targets , outputs = (result ["inputs" ], result ["targets" ],
604- result ["outputs" ])
605- if FLAGS .decode_return_beams :
606- output_beams = np .split (outputs , FLAGS .decode_beam_size , axis = 0 )
607- for k , beam in enumerate (output_beams ):
608- tf .logging .info ("BEAM %d:" % k )
609- log_fn (inputs , targets , beam , problem , j )
610- else :
611- log_fn (inputs , targets , outputs , problem , j )
612- else :
613- for j , (inputs , targets , outputs ) in enumerate (
614- zip (result_iter ["inputs" ], result_iter ["targets" ], result_iter [
615- "outputs" ])):
616- if FLAGS .decode_return_beams :
617- output_beams = np .split (outputs , FLAGS .decode_beam_size , axis = 0 )
618- for k , beam in enumerate (output_beams ):
619- tf .logging .info ("BEAM %d:" % k )
620- log_fn (inputs , targets , beam , problem , j )
621- else :
622- log_fn (inputs , targets , outputs , problem , j )
614+ for j , result in enumerate (result_iter ):
615+ inputs , targets , outputs = (result ["inputs" ], result ["targets" ],
616+ result ["outputs" ])
617+ if FLAGS .decode_return_beams :
618+ output_beams = np .split (outputs , FLAGS .decode_beam_size , axis = 0 )
619+ for k , beam in enumerate (output_beams ):
620+ tf .logging .info ("BEAM %d:" % k )
621+ log_fn (inputs , targets , beam , problem , j )
622+ else :
623+ log_fn (inputs , targets , outputs , problem , j )
623624
624625
625626def decode_from_file (estimator , filename ):
@@ -628,22 +629,12 @@ def decode_from_file(estimator, filename):
628629 problem_id = FLAGS .decode_problem_id
629630 inputs_vocab = hparams .problems [problem_id ].vocabulary ["inputs" ]
630631 targets_vocab = hparams .problems [problem_id ].vocabulary ["targets" ]
631- tf .logging .info ("Performing Decoding from a file." )
632+ tf .logging .info ("Performing decoding from a file." )
632633 sorted_inputs , sorted_keys = _get_sorted_inputs (filename )
633634 num_decode_batches = (len (sorted_inputs ) - 1 ) // FLAGS .decode_batch_size + 1
634635 input_fn = _decode_batch_input_fn (problem_id , num_decode_batches ,
635636 sorted_inputs , inputs_vocab )
636637
637- # strips everything after the first <EOS> id, which is assumed to be 1
638- def _save_until_eos (hyp ): # pylint: disable=missing-docstring
639- ret = []
640- index = 0
641- # until you reach <EOS> id
642- while index < len (hyp ) and hyp [index ] != 1 :
643- ret .append (hyp [index ])
644- index += 1
645- return np .array (ret )
646-
647638 decodes = []
648639 for _ in range (num_decode_batches ):
649640 result_iter = estimator .predict (
@@ -655,8 +646,7 @@ def log_fn(inputs, outputs):
655646 decoded_inputs = inputs_vocab .decode (_save_until_eos (inputs .flatten ()))
656647 tf .logging .info ("Inference results INPUT: %s" % decoded_inputs )
657648
658- decoded_outputs = targets_vocab .decode (
659- _save_until_eos (outputs .flatten ()))
649+ decoded_outputs = targets_vocab .decode (_save_until_eos (outputs .flatten ()))
660650 tf .logging .info ("Inference results OUTPUT: %s" % decoded_outputs )
661651 return decoded_outputs
662652
@@ -667,7 +657,7 @@ def log_fn(inputs, outputs):
667657 for k , beam in enumerate (output_beams ):
668658 tf .logging .info ("BEAM %d:" % k )
669659 beam_decodes .append (log_fn (result ["inputs" ], beam ))
670- decodes .append (str . join ( "\t " , beam_decodes ))
660+ decodes .append ("\t " . join ( beam_decodes ))
671661
672662 else :
673663 decodes .append (log_fn (result ["inputs" ], result ["outputs" ]))
@@ -709,11 +699,11 @@ def decode_interactively(estimator):
709699 tf .logging .info ("BEAM %d:" % k )
710700 if scores is not None :
711701 tf .logging .info ("%s\t Score:%f" %
712- (targets_vocab .decode (beam .flatten ()), scores [k ]))
702+ (targets_vocab .decode (_save_until_eos ( beam .flatten () )), scores [k ]))
713703 else :
714- tf .logging .info (targets_vocab .decode (beam .flatten ()))
704+ tf .logging .info (targets_vocab .decode (_save_until_eos ( beam .flatten () )))
715705 else :
716- tf .logging .info (targets_vocab .decode (result ["outputs" ].flatten ()))
706+ tf .logging .info (targets_vocab .decode (_save_until_eos ( result ["outputs" ].flatten () )))
717707
718708
719709def _decode_batch_input_fn (problem_id , num_decode_batches , sorted_inputs ,
@@ -726,10 +716,10 @@ def _decode_batch_input_fn(problem_id, num_decode_batches, sorted_inputs,
726716 tf .logging .info ("Deocding batch %d" % b )
727717 batch_length = 0
728718 batch_inputs = []
729- for inputs in sorted_inputs [b * FLAGS .decode_batch_size :(
730- b + 1 ) * FLAGS .decode_batch_size ]:
719+ for inputs in sorted_inputs [b * FLAGS .decode_batch_size :
720+ ( b + 1 ) * FLAGS .decode_batch_size ]:
731721 input_ids = vocabulary .encode (inputs )
732- input_ids .append (1 ) # Assuming EOS=1.
722+ input_ids .append (EOS_TOKEN )
733723 batch_inputs .append (input_ids )
734724 if len (input_ids ) > batch_length :
735725 batch_length = len (input_ids )
@@ -822,7 +812,7 @@ def _interactive_input_fn(hparams):
822812 if input_type == "text" :
823813 input_ids = vocabulary .encode (input_string )
824814 if has_input :
825- input_ids .append (1 ) # assume 1 means end-of-source
815+ input_ids .append (EOS_TOKEN )
826816 x = [num_samples , decode_length , len (input_ids )] + input_ids
827817 assert len (x ) < const_array_size
828818 x += [0 ] * (const_array_size - len (x ))
@@ -1089,7 +1079,7 @@ def input_fn():
10891079 problem_choice = tf .to_int32 (FLAGS .worker_id % problem_count )
10901080 else :
10911081 raise ValueError ("Value of hparams.problem_choice is %s and must be "
1092- "one of [uniform, adaptive, distributed]" ,
1082+ "one of [uniform, adaptive, distributed]" %
10931083 hparams .problem_choice )
10941084
10951085 # Inputs and targets conditional on problem_choice.
0 commit comments