1313# limitations under the License.
1414
1515"""A common dataset reader."""
16+ import dataclasses
1617import random
1718from typing import Any , Callable , Dict , List , Optional , Sequence , Text , Union
1819
@@ -159,20 +160,20 @@ def _read_tfds(tfds_name: Text,
159160 cycle_length : Optional [int ] = None ,
160161 block_length : Optional [int ] = None ) -> tf .data .Dataset :
161162 """Reads a dataset from tfds."""
162- repeat_filenames = is_training and not cache
163+ read_config = tfds .ReadConfig (
164+ interleave_cycle_length = cycle_length ,
165+ interleave_block_length = block_length ,
166+ input_context = input_context ,
167+ shuffle_seed = seed ,
168+ repeat_filenames = is_training and not cache ,
169+ skip_prefetch = True )
170+
163171 decoders = {}
164172 if tfds_skip_decoding_feature :
165173 for skip_feature in tfds_skip_decoding_feature .split (',' ):
166174 decoders [skip_feature .strip ()] = tfds .decode .SkipDecoding ()
167175
168176 if tfds_name .startswith ('mldataset.' ):
169- read_config = tfds .ReadConfig (
170- interleave_cycle_length = cycle_length ,
171- interleave_block_length = block_length ,
172- input_context = input_context ,
173- shuffle_seed = seed ,
174- repeat_filenames = repeat_filenames ,
175- skip_prefetch = True )
176177 dataset = tfds .load (name = tfds_name ,
177178 split = tfds_split ,
178179 as_supervised = tfds_as_supervised ,
@@ -196,25 +197,12 @@ def _read_tfds(tfds_name: Text,
196197 # The number of files in the dataset split is smaller than the number of
197198 # input pipelines. We read the entire dataset first and then shard in the
198199 # host memory.
199- read_config = tfds .ReadConfig (
200- interleave_cycle_length = cycle_length ,
201- interleave_block_length = block_length ,
202- input_context = None ,
203- shuffle_seed = seed ,
204- repeat_filenames = repeat_filenames ,
205- skip_prefetch = True )
200+ read_config = dataclasses .replace (read_config , input_context = None )
206201 load_kwargs .update ({'read_config' : read_config })
207202 dataset = tfds .load (** load_kwargs )
208203 dataset = dataset .shard (input_context .num_input_pipelines ,
209204 input_context .input_pipeline_id )
210205 else :
211- read_config = tfds .ReadConfig (
212- interleave_cycle_length = cycle_length ,
213- interleave_block_length = block_length ,
214- input_context = input_context ,
215- shuffle_seed = seed ,
216- repeat_filenames = repeat_filenames ,
217- skip_prefetch = True )
218206 load_kwargs .update ({'read_config' : read_config })
219207 dataset = tfds .load (** load_kwargs )
220208 return dataset