Skip to content

Commit 3b72ebf

Browse files
authored
Merge pull request #140 from yoniaflalo/PR_MultiEpochsDataLoader
added MultiEpochsDataLoader
2 parents 8d8677e + a7f570c commit 3b72ebf

File tree

2 files changed

+42
-1
lines changed

2 files changed

+42
-1
lines changed

timm/data/loader.py

Lines changed: 39 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -140,6 +140,7 @@ def create_loader(
140140
pin_memory=False,
141141
fp16=False,
142142
tf_preprocessing=False,
143+
use_multi_epochs_loader=False
143144
):
144145
re_num_splits = 0
145146
if re_split:
@@ -175,7 +176,12 @@ def create_loader(
175176
if collate_fn is None:
176177
collate_fn = fast_collate if use_prefetcher else torch.utils.data.dataloader.default_collate
177178

178-
loader = torch.utils.data.DataLoader(
179+
loader_class = torch.utils.data.DataLoader
180+
181+
if use_multi_epochs_loader:
182+
loader_class = MultiEpochsDataLoader
183+
184+
loader = loader_class(
179185
dataset,
180186
batch_size=batch_size,
181187
shuffle=sampler is None and is_training,
@@ -198,3 +204,35 @@ def create_loader(
198204
)
199205

200206
return loader
207+
208+
209+
class MultiEpochsDataLoader(torch.utils.data.DataLoader):
210+
211+
def __init__(self, *args, **kwargs):
212+
super().__init__(*args, **kwargs)
213+
self._DataLoader__initialized = False
214+
self.batch_sampler = _RepeatSampler(self.batch_sampler)
215+
self._DataLoader__initialized = True
216+
self.iterator = super().__iter__()
217+
218+
def __len__(self):
219+
return len(self.batch_sampler.sampler)
220+
221+
def __iter__(self):
222+
for i in range(len(self)):
223+
yield next(self.iterator)
224+
225+
226+
class _RepeatSampler(object):
227+
""" Sampler that repeats forever.
228+
229+
Args:
230+
sampler (Sampler)
231+
"""
232+
233+
def __init__(self, sampler):
234+
self.sampler = sampler
235+
236+
def __iter__(self):
237+
while True:
238+
yield from iter(self.sampler)

train.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -198,6 +198,8 @@
198198
parser.add_argument('--tta', type=int, default=0, metavar='N',
199199
help='Test/inference time augmentation (oversampling) factor. 0=None (default: 0)')
200200
parser.add_argument("--local_rank", default=0, type=int)
201+
parser.add_argument('--use-multi-epochs-loader', action='store_true', default=False,
202+
help='use the multi-epochs-loader to save time at the beginning of every epoch')
201203

202204

203205
def _parse_args():
@@ -391,6 +393,7 @@ def main():
391393
distributed=args.distributed,
392394
collate_fn=collate_fn,
393395
pin_memory=args.pin_mem,
396+
use_multi_epochs_loader=args.use_multi_epochs_loader
394397
)
395398

396399
eval_dir = os.path.join(args.data, 'val')

0 commit comments

Comments
 (0)