@@ -52,7 +52,8 @@ def decode_hparams(overrides=""):
5252 return_beams = False ,
5353 max_input_size = - 1 ,
5454 identity_output = False ,
55- num_samples = - 1 )
55+ num_samples = - 1 ,
56+ delimiter = "\n " )
5657 hp = hp .parse (overrides )
5758 return hp
5859
@@ -176,8 +177,8 @@ def decode_from_dataset(estimator,
176177 # Write out predictions if decode_to_file passed
177178 if decode_to_file :
178179 for decoded_output , decoded_target in decoded_outputs :
179- output_file .write (str (decoded_output ) + " \n " )
180- target_file .write (str (decoded_target ) + " \n " )
180+ output_file .write (str (decoded_output ) + decode_hp . delimiter )
181+ target_file .write (str (decoded_target ) + decode_hp . delimiter )
181182
182183 if (decode_hp .num_samples >= 0 and
183184 num_predictions >= decode_hp .num_samples ):
@@ -203,7 +204,8 @@ def decode_from_file(estimator, filename, decode_hp, decode_to_file=None):
203204 targets_vocab = hparams .problems [problem_id ].vocabulary ["targets" ]
204205 problem_name = FLAGS .problems .split ("-" )[problem_id ]
205206 tf .logging .info ("Performing decoding from a file." )
206- sorted_inputs , sorted_keys = _get_sorted_inputs (filename , decode_hp .shards )
207+ sorted_inputs , sorted_keys = _get_sorted_inputs (filename , decode_hp .shards ,
208+ decode_hp .delimiter )
207209 num_decode_batches = (len (sorted_inputs ) - 1 ) // decode_hp .batch_size + 1
208210
209211 def input_fn ():
@@ -251,7 +253,7 @@ def input_fn():
251253 tf .logging .info ("Writing decodes into %s" % decode_filename )
252254 outfile = tf .gfile .Open (decode_filename , "w" )
253255 for index in range (len (sorted_inputs )):
254- outfile .write ("%s\n " % (decodes [sorted_keys [index ]]))
256+ outfile .write ("%s%s " % (decodes [sorted_keys [index ]], decode_hp . delimiter ))
255257
256258
257259def _decode_filename (base_filename , problem_name , decode_hp ):
@@ -472,13 +474,14 @@ def show_and_save_image(img, save_path):
472474 plt .savefig (save_path )
473475
474476
475- def _get_sorted_inputs (filename , num_shards = 1 ):
477+ def _get_sorted_inputs (filename , num_shards = 1 , delimiter = " \n " ):
476478 """Returning inputs sorted according to length.
477479
478480 Args:
479481 filename: path to file with inputs, 1 per line.
480482 num_shards: number of input shards. If > 1, will read from file filename.XX,
481483 where XX is FLAGS.worker_id.
484+ delimiter: str, delimits records in the file.
482485
483486 Returns:
484487 a sorted list of inputs
@@ -490,8 +493,12 @@ def _get_sorted_inputs(filename, num_shards=1):
490493 decode_filename = filename + ("%.2d" % FLAGS .worker_id )
491494 else :
492495 decode_filename = filename
493- inputs = [line .strip () for line in tf .gfile .Open (decode_filename )]
494- input_lens = [(i , len (line .strip ().split ())) for i , line in enumerate (inputs )]
496+
497+ with tf .gfile .Open (decode_filename ) as f :
498+ text = f .read ()
499+ records = text .split (delimiter )
500+ inputs = [record .strip () for record in records ]
501+ input_lens = [(i , len (line .split ())) for i , line in enumerate (inputs )]
495502 sorted_input_lens = sorted (input_lens , key = operator .itemgetter (1 ))
496503 # We'll need the keys to rearrange the inputs back into their original order
497504 sorted_keys = {}
@@ -553,8 +560,8 @@ def input_fn(problem_choice, x=inputs): # pylint: disable=missing-docstring
553560 feature_map ["problem_choice" ])
554561 features ["input_space_id" ] = input_space_id
555562 features ["target_space_id" ] = target_space_id
556- features ["decode_length" ] = (IMAGE_DECODE_LENGTH
557- if input_is_image else inputs [1 ])
563+ features ["decode_length" ] = (
564+ IMAGE_DECODE_LENGTH if input_is_image else inputs [1 ])
558565 features ["inputs" ] = x
559566 return features
560567
@@ -588,7 +595,7 @@ def input_fn(problem_choice, x=inputs): # pylint: disable=missing-docstring
588595 features ["problem_choice" ] = feature_map ["problem_choice" ]
589596 features ["input_space_id" ] = input_space_id
590597 features ["target_space_id" ] = target_space_id
591- features ["decode_length" ] = (IMAGE_DECODE_LENGTH
592- if input_is_image else tf .shape (x )[1 ] + 50 )
598+ features ["decode_length" ] = (
599+ IMAGE_DECODE_LENGTH if input_is_image else tf .shape (x )[1 ] + 50 )
593600 features ["inputs" ] = x
594601 return features
0 commit comments