88
99from .constants import IMAGENET_DEFAULT_MEAN , IMAGENET_DEFAULT_STD
1010from .loader import _worker_init , adapt_to_chs
11- from .naflex_dataset import VariableSeqMapWrapper , NaFlexCollator
11+ from .naflex_dataset import NaFlexMapDatasetWrapper , NaFlexCollator
1212from .naflex_random_erasing import PatchRandomErasing
1313from .transforms_factory import create_transform
1414
@@ -215,7 +215,7 @@ def create_naflex_loader(
215215 if isinstance (dataset , torch .utils .data .IterableDataset ):
216216 assert False , "IterableDataset Wrapper is a WIP"
217217
218- naflex_dataset = VariableSeqMapWrapper (
218+ naflex_dataset = NaFlexMapDatasetWrapper (
219219 dataset ,
220220 transform_factory = transform_factory ,
221221 patch_size = patch_size ,
@@ -231,18 +231,12 @@ def create_naflex_loader(
231231 )
232232
233233 # NOTE: Collation is handled by the dataset wrapper for training
234- # Create the collator (handles fixed-size collation)
235- # collate_fn = NaFlexCollator(
236- # max_seq_len=max(seq_lens) + 1, # +1 for class token
237- # )
238-
239234 loader = torch .utils .data .DataLoader (
240235 naflex_dataset ,
241236 batch_size = None ,
242237 shuffle = False ,
243238 num_workers = num_workers ,
244239 sampler = None ,
245- #collate_fn=collate_fn,
246240 pin_memory = pin_memory ,
247241 worker_init_fn = partial (_worker_init , worker_seeding = worker_seeding ),
248242 persistent_workers = persistent_workers
0 commit comments