@@ -69,44 +69,11 @@ def input_fn(mode, params):
       },
   }
 
-  def decode_record(record):
-    """Serialized Example to dict of <feature name, Tensor>."""
-    data_fields, _ = problem.example_reading_spec()
-    decoded = tf.parse_single_example(record, features=data_fields)
-    decoded["inputs"] = decoded["inputs"].values
-    decoded["targets"] = decoded["targets"].values
-    return decoded
-
-  data_files = tf.contrib.slim.parallel_reader.get_data_files(
-      problem.filepattern(data_dir, mode))
-  dataset = tf.data.TFRecordDataset(data_files)
-  dataset = dataset.map(decode_record, num_parallel_calls=num_threads)
-
-  def _preprocess(example, problem, hparams, mode):
-    example = problem.preprocess_example(example, mode, hparams)
-    # We do not want int64s as they are not supported on TPUs.
-    example = data_reader.cast_int64_to_int32(example)
-    return example
-
-  dataset = dataset.map(
-      lambda ex: _preprocess(ex, problem, hparams, mode),
-      num_parallel_calls=num_threads)
-
   def _valid_size(example):
     return data_reader.example_valid_size(
         example, batching_scheme["min_length"], batching_scheme["max_length"])
 
-  dataset = dataset.filter(_valid_size)
-  # TODO(rsepassi): In eval mode, should not repeat
-  dataset = dataset.repeat(None)
-  dataset = data_reader.padded_batch(dataset, batch_size,
-                                     batching_scheme["padded_shapes"])
-
-  if not is_training:
-    dataset = dataset.map(
-        lambda f: pad_batch(f, batch_size), num_parallel_calls=num_threads)
-
-  def shape_def(example):
+  def define_shapes(example):
     """Set the right shapes for the features."""
     inputs = example["inputs"]
     targets = example["targets"]
@@ -130,7 +97,22 @@ def shape_def(example):
 
     return example
 
-  dataset = dataset.map(shape_def, num_parallel_calls=num_threads)
+  dataset = problem.dataset(
+      mode=mode, data_dir=data_dir, num_threads=num_threads, hparams=hparams)
+  dataset = dataset.map(
+      data_reader.cast_int64_to_int32, num_parallel_calls=num_threads)
+  dataset = dataset.filter(_valid_size)
+  if is_training:
+    dataset = dataset.shuffle(100)
+  # TODO(rsepassi): In eval mode, should not repeat. Do so because TPU seems
+  # to crash if it runs out of data during eval.
+  dataset = dataset.repeat(None)
+  dataset = data_reader.padded_batch(dataset, batch_size,
+                                     batching_scheme["padded_shapes"])
+  if not is_training:
+    dataset = dataset.map(
+        lambda f: pad_batch(f, batch_size), num_parallel_calls=num_threads)
+  dataset = dataset.map(define_shapes, num_parallel_calls=num_threads)
   dataset = dataset.prefetch(1)
   features = dataset.make_one_shot_iterator().get_next()
 
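For context, here is a minimal sketch of the reading/decoding pipeline that the new problem.dataset() call is assumed to encapsulate, reconstructed from the code removed above. The function name and signature are illustrative, not the actual Problem API; note the int64 -> int32 cast that used to happen here is now applied by the caller.

# Sketch of what problem.dataset() presumably wraps (TF 1.x),
# reconstructed from the removed lines above.
import tensorflow as tf

def problem_dataset_sketch(problem, mode, data_dir, num_threads, hparams):
  """Returns a tf.data.Dataset of decoded, preprocessed examples."""

  def decode_record(record):
    # Parse a serialized tf.Example into a dict of Tensors. "inputs" and
    # "targets" come back as SparseTensors, so keep only their values.
    data_fields, _ = problem.example_reading_spec()
    decoded = tf.parse_single_example(record, features=data_fields)
    decoded["inputs"] = decoded["inputs"].values
    decoded["targets"] = decoded["targets"].values
    return decoded

  data_files = tf.contrib.slim.parallel_reader.get_data_files(
      problem.filepattern(data_dir, mode))
  dataset = tf.data.TFRecordDataset(data_files)
  dataset = dataset.map(decode_record, num_parallel_calls=num_threads)
  # Problem-specific preprocessing (truncation, augmentation, etc.). The
  # caller now handles the TPU-required int64 -> int32 cast.
  return dataset.map(
      lambda ex: problem.preprocess_example(ex, mode, hparams),
      num_parallel_calls=num_threads)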
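And a hypothetical sketch of how an input_fn with this (mode, params) signature gets consumed: TPUEstimator calls the input function with a params dict carrying the per-shard batch size, so mode is typically bound beforehand. model_fn, master, and the config values below are placeholders, not values from this commit.

# Hypothetical wiring into a TPUEstimator (TF 1.x contrib API).
import functools
import tensorflow as tf

run_config = tf.contrib.tpu.RunConfig(
    master=master,  # placeholder: address of the TPU worker
    model_dir="/tmp/model",
    tpu_config=tf.contrib.tpu.TPUConfig(iterations_per_loop=100))
estimator = tf.contrib.tpu.TPUEstimator(
    model_fn=model_fn,  # placeholder: the model's model_fn
    config=run_config,
    use_tpu=True,
    train_batch_size=64)
# Bind mode first; TPUEstimator then invokes the callable with params,
# including params["batch_size"].
estimator.train(
    functools.partial(input_fn, tf.estimator.ModeKeys.TRAIN),
    max_steps=1000)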