1 | | -""" |
2 | | -Dynamic Sequence Length Datasets for Variable Resolution Image Processing |
| 1 | +""" Dynamic Sequence Length Datasets for Variable Resolution Image Processing |
3 | 2 |
4 | 3 | Implements two dataset wrappers: |
5 | | -1. DynamicSeqMapDataset - Map-style dataset that returns batches with variable sequence lengths |
6 | | -2. DynamicSeqIterDataset - Iterable dataset that yields batches with variable sequence lengths |
| 4 | +1. NaFlexMapDatasetWrapper - Map-style dataset that returns batches with variable sequence lengths |
| 5 | +TODO: 2. NaFlexIterableDatasetWrapper - Iterable dataset that yields batches with variable sequence lengths |
7 | 6 |
8 | 7 | Both support: |
9 | 8 | - Pre-initialized transforms for efficiency |
10 | 9 | - Distributed training |
11 | 10 | - Multiple workers |
12 | 11 | - Variable batch sizes based on sequence length |
| 12 | +
| 13 | +Hacked together by / Copyright 2025, Ross Wightman, Hugging Face |
13 | 14 | """ |
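As the module docstring above notes, sequence length tracks image resolution: each image is patchified into a grid of tokens, so a larger image yields a longer sequence. A minimal standalone sketch of that relationship (illustrative only, not code from this commit; the patch size of 16 is an assumption):

    import math

    def seq_len_for_image(height, width, patch_size=16):
        # Number of patch tokens a variable-resolution image produces.
        return math.ceil(height / patch_size) * math.ceil(width / patch_size)

    seq_len_for_image(224, 224)  # 196 tokens
    seq_len_for_image(384, 256)  # 384 tokens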
14 | 15 |
15 | 16 | import math |
20 | 21 |
21 | 22 | import torch |
22 | 23 | from torch.utils.data import Dataset, IterableDataset, DataLoader |
23 | | -from torchvision import transforms |
24 | 24 | from PIL import Image |
25 | 25 |
26 | | - |
27 | | -from .naflex_transforms import Patchify, patchify_image |
28 | | -from ..layers import to_2tuple |
| 26 | +from .naflex_transforms import Patchify |
| 27 | +from timm.layers import to_2tuple |
29 | 28 |
30 | 29 |
31 | 30 | def calculate_naflex_batch_size( |
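The first hunk ends at the signature of calculate_naflex_batch_size; its parameters and body are not shown here. A common way to choose a batch size for a given sequence length, offered only as a hedged sketch of the idea (names and defaults are assumptions, not this function's actual implementation), is to keep the total token count per batch near a fixed budget:

    def batch_size_for_seq_len(seq_len, max_tokens=65536, divisor=8, min_bs=1, max_bs=None):
        # More tokens per sample -> fewer samples per batch, so tokens per batch stays roughly constant.
        bs = max(min_bs, max_tokens // seq_len)
        bs = (bs // divisor) * divisor or min_bs  # round down to a friendly multiple
        if max_bs is not None:
            bs = min(bs, max_bs)
        return bs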
@@ -203,7 +202,7 @@ class NaFlexMapDatasetWrapper(IterableDataset): |
203 | 202 | Yields batches with variable sequence lengths. It calculates a canonical |
204 | 203 | batch schedule (sequence length, batch size pairs) once based on the |
205 | 204 | total dataset size (padded for distribution). Each epoch, it shuffles |
206 | | - the *order* of this canonical schedule and the dataset indices. |
| 205 | + the order of this canonical schedule and the dataset indices. |
207 | 206 | This ensures a consistent number of batches and samples per epoch |
208 | 207 | across all ranks. Handles distributed training and multiple workers. |
209 | 208 | """ |
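The class docstring above separates two pieces of state: a canonical (seq_len, batch_size) schedule computed once from the padded dataset size, and a per-epoch shuffle that only reorders that schedule and the sample indices. A rough sketch of how such a scheme can work (hypothetical helper names, not the wrapper's actual implementation):

    import random

    def make_canonical_schedule(padded_num_samples, seq_lens, max_tokens_per_batch):
        # Built once: the number of batches and samples per epoch never changes.
        schedule = []  # list of (seq_len, batch_size) pairs
        remaining, i = padded_num_samples, 0
        while remaining > 0:
            seq_len = seq_lens[i % len(seq_lens)]
            bs = min(remaining, max(1, max_tokens_per_batch // seq_len))
            schedule.append((seq_len, bs))
            remaining -= bs
            i += 1
        return schedule

    def shuffle_for_epoch(schedule, indices, epoch, seed=42):
        # Per epoch, only the order of the fixed schedule and of the indices changes,
        # so every rank sees the same batch and sample counts each epoch.
        rng = random.Random(seed + epoch)
        order, idx = list(schedule), list(indices)
        rng.shuffle(order)
        rng.shuffle(idx)
        return order, idx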
@@ -292,13 +291,13 @@ def __init__( |
292 | 291 |
293 | 292 | self.mixup_fn = mixup_fn |
294 | 293 |
295 | | - # --- Canonical Schedule Calculation (Done Once) --- |
| 294 | + # Canonical Schedule Calculation (Done Once) |
296 | 295 | self._canonical_batch_schedule: List[Tuple[int, int]] = [] |
297 | 296 | self._num_batches_per_rank: int = 0 |
298 | 297 | self._padded_samples_per_rank: int = 0 |
299 | 298 | self._create_canonical_schedule() # Calculate schedule based on padded size |
300 | 299 |
301 | | - # --- Per-Epoch State --- |
| 300 | + # Per-Epoch State |
302 | 301 | # Stores (seq_len, list_of_indices) for the current epoch, specific to this rank |
303 | 302 | self._epoch_batches: List[Tuple[int, List[int]]] = [] |
304 | 303 | self._prepare_epoch_batches(self.epoch) # setup for initial epoch |
@@ -420,7 +419,6 @@ def _prepare_epoch_batches(self, epoch: int): |
420 | 419 | if len(indices_for_ranks) != padded_total_len: |
421 | 420 | raise RuntimeError(f"Internal Error: Padded index list length {len(indices_for_ranks)} does not match expected {padded_total_len}") |
422 | 421 |
423 | | - |
424 | 422 | # 3. Select indices for the current rank |
425 | 423 | if self.distributed and self.world_size > 1: |
426 | 424 | indices_this_rank = indices_for_ranks[self.rank::self.world_size] |
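The strided slice indices_for_ranks[self.rank::self.world_size] in the hunk above hands each rank a disjoint, equally sized share of the padded index list. A quick worked example with made-up numbers:

    indices_for_ranks = list(range(12))  # already padded so the length divides evenly
    world_size = 4
    shards = [indices_for_ranks[rank::world_size] for rank in range(world_size)]
    # rank 0 -> [0, 4, 8], rank 1 -> [1, 5, 9], rank 2 -> [2, 6, 10], rank 3 -> [3, 7, 11]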