Skip to content

Commit d78cbf4

Browse files
committed
Rename dataset wrapper to NaFlexMapDatasetWrapper
1 parent dd3b96c commit d78cbf4

File tree

3 files changed

+4
-10
lines changed

3 files changed

+4
-10
lines changed

timm/data/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from .imagenet_info import ImageNetInfo, infer_imagenet_subset
99
from .loader import create_loader
1010
from .mixup import Mixup, FastCollateMixup
11-
from .naflex_dataset import VariableSeqMapWrapper, calculate_naflex_batch_size
11+
from .naflex_dataset import NaFlexMapDatasetWrapper, calculate_naflex_batch_size
1212
from .naflex_loader import create_naflex_loader
1313
from .naflex_mixup import NaFlexMixup, pairwise_mixup_target, mix_batch_variable_size
1414
from .naflex_transforms import (

timm/data/naflex_dataset.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -123,7 +123,7 @@ def __call__(self, batch):
123123
}, targets
124124

125125

126-
class VariableSeqMapWrapper(IterableDataset):
126+
class NaFlexMapDatasetWrapper(IterableDataset):
127127
"""
128128
IterableDataset wrapper for a map-style base dataset.
129129

timm/data/naflex_loader.py

Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88

99
from .constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
1010
from .loader import _worker_init, adapt_to_chs
11-
from .naflex_dataset import VariableSeqMapWrapper, NaFlexCollator
11+
from .naflex_dataset import NaFlexMapDatasetWrapper, NaFlexCollator
1212
from .naflex_random_erasing import PatchRandomErasing
1313
from .transforms_factory import create_transform
1414

@@ -215,7 +215,7 @@ def create_naflex_loader(
215215
if isinstance(dataset, torch.utils.data.IterableDataset):
216216
assert False, "IterableDataset Wrapper is a WIP"
217217

218-
naflex_dataset = VariableSeqMapWrapper(
218+
naflex_dataset = NaFlexMapDatasetWrapper(
219219
dataset,
220220
transform_factory=transform_factory,
221221
patch_size=patch_size,
@@ -231,18 +231,12 @@ def create_naflex_loader(
231231
)
232232

233233
# NOTE: Collation is handled by the dataset wrapper for training
234-
# Create the collator (handles fixed-size collation)
235-
# collate_fn = NaFlexCollator(
236-
# max_seq_len=max(seq_lens) + 1, # +1 for class token
237-
# )
238-
239234
loader = torch.utils.data.DataLoader(
240235
naflex_dataset,
241236
batch_size=None,
242237
shuffle=False,
243238
num_workers=num_workers,
244239
sampler=None,
245-
#collate_fn=collate_fn,
246240
pin_memory=pin_memory,
247241
worker_init_fn=partial(_worker_init, worker_seeding=worker_seeding),
248242
persistent_workers=persistent_workers

0 commit comments

Comments
 (0)