@@ -140,6 +140,7 @@ def create_loader(
         pin_memory=False,
         fp16=False,
         tf_preprocessing=False,
+        use_multi_epochs_loader=False
 ):
     re_num_splits = 0
     if re_split:
@@ -175,7 +176,12 @@ def create_loader(
     if collate_fn is None:
         collate_fn = fast_collate if use_prefetcher else torch.utils.data.dataloader.default_collate

-    loader = torch.utils.data.DataLoader(
+    loader_class = torch.utils.data.DataLoader
+
+    if use_multi_epochs_loader:
+        loader_class = MultiEpochsDataLoader
+
+    loader = loader_class(
         dataset,
         batch_size=batch_size,
         shuffle=sampler is None and is_training,
@@ -198,3 +204,35 @@ def create_loader(
     )

     return loader
+
+
+class MultiEpochsDataLoader(torch.utils.data.DataLoader):
+    """DataLoader that keeps a single iterator alive so worker processes are reused across epochs."""
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        # Temporarily clear the name-mangled DataLoader.__initialized flag so the
+        # batch_sampler can be swapped for a repeating wrapper after construction.
+        self._DataLoader__initialized = False
+        self.batch_sampler = _RepeatSampler(self.batch_sampler)
+        self._DataLoader__initialized = True
+        # Persistent iterator: workers are created once here and reused for every epoch.
+        self.iterator = super().__iter__()
+
+    def __len__(self):
+        # One epoch is the length of the wrapped (original) batch sampler.
+        return len(self.batch_sampler.sampler)
+
+    def __iter__(self):
+        # Yield exactly one epoch's worth of batches from the persistent iterator.
+        for i in range(len(self)):
+            yield next(self.iterator)
+
+
+class _RepeatSampler(object):
+    """ Sampler that repeats forever.
+
+    Args:
+        sampler (Sampler)
+    """
+
+    def __init__(self, sampler):
+        self.sampler = sampler
+
+    def __iter__(self):
+        while True:
+            yield from iter(self.sampler)
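For reference, MultiEpochsDataLoader behaves as a drop-in replacement for torch.utils.data.DataLoader: because super().__iter__() runs once in __init__ and _RepeatSampler never exhausts, worker processes are spawned a single time and reused for every epoch rather than being torn down and restarted on each pass over the loader. In create_loader the same behaviour is opted into with use_multi_epochs_loader=True. A minimal usage sketch follows; the toy dataset, batch size, and worker count are illustrative and not part of this commit.

import torch
from torch.utils.data import TensorDataset

if __name__ == "__main__":  # guard needed when num_workers > 0 on spawn-based platforms
    # Toy dataset standing in for a real one.
    dataset = TensorDataset(torch.randn(256, 3, 32, 32), torch.randint(0, 10, (256,)))

    # Constructed exactly like a stock DataLoader; workers start once, inside __init__.
    loader = MultiEpochsDataLoader(dataset, batch_size=32, shuffle=True, num_workers=2)

    for epoch in range(3):
        # Each pass still yields one epoch's worth of batches (len(loader) == 8 here),
        # but no workers are re-created between epochs.
        for images, targets in loader:
            pass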