@@ -159,6 +159,7 @@ def _read_tfds(tfds_name: Text,
159159 cycle_length : Optional [int ] = None ,
160160 block_length : Optional [int ] = None ) -> tf .data .Dataset :
161161 """Reads a dataset from tfds."""
162+ repeat_filenames = is_training and not cache
162163 decoders = {}
163164 if tfds_skip_decoding_feature :
164165 for skip_feature in tfds_skip_decoding_feature .split (',' ):
@@ -170,6 +171,7 @@ def _read_tfds(tfds_name: Text,
170171 interleave_block_length = block_length ,
171172 input_context = input_context ,
172173 shuffle_seed = seed ,
174+ repeat_filenames = repeat_filenames ,
173175 skip_prefetch = True )
174176 dataset = tfds .load (name = tfds_name ,
175177 split = tfds_split ,
@@ -199,6 +201,7 @@ def _read_tfds(tfds_name: Text,
199201 interleave_block_length = block_length ,
200202 input_context = None ,
201203 shuffle_seed = seed ,
204+ repeat_filenames = repeat_filenames ,
202205 skip_prefetch = True )
203206 load_kwargs .update ({'read_config' : read_config })
204207 dataset = tfds .load (** load_kwargs )
@@ -210,12 +213,10 @@ def _read_tfds(tfds_name: Text,
210213 interleave_block_length = block_length ,
211214 input_context = input_context ,
212215 shuffle_seed = seed ,
216+ repeat_filenames = repeat_filenames ,
213217 skip_prefetch = True )
214218 load_kwargs .update ({'read_config' : read_config })
215219 dataset = tfds .load (** load_kwargs )
216-
217- if is_training and not cache :
218- dataset = dataset .repeat ()
219220 return dataset
220221
221222
0 commit comments