Commit 0d43942

Add variable patch size to naflex training, improve patch size arg handling from train.py onwards. Add docstrings and type annotations (thanks Claude).
1 parent d78cbf4 commit 0d43942

8 files changed: +574, -188 lines

timm/data/naflex_dataset.py

Lines changed: 165 additions & 40 deletions
@@ -25,16 +25,28 @@
 
 
 from .naflex_transforms import Patchify, patchify_image
+from ..layers import to_2tuple
 
 
 def calculate_naflex_batch_size(
         tokens_per_batch: int,
         seq_len: int,
         max_size: Optional[int] = None,
         divisor: int = 1,
-        rounding: str ='floor',
-):
-    """Calculate batch size based on sequence length with divisibility constraints."""
+        rounding: str = 'floor',
+) -> int:
+    """Calculate batch size based on sequence length with divisibility constraints.
+
+    Args:
+        tokens_per_batch: Target number of tokens per batch.
+        seq_len: Sequence length for this batch.
+        max_size: Optional maximum batch size.
+        divisor: Ensure batch size is divisible by this value.
+        rounding: Rounding method ('floor', 'ceil', 'round').
+
+    Returns:
+        Calculated batch size.
+    """
     # Calculate raw batch size based on sequence length
     raw_batch_size = tokens_per_batch / seq_len
 
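For intuition, a minimal sketch of the token-budget math this function performs, assuming the divisor and rounding handling follow the docstring (only the raw_batch_size line is visible in this hunk, so the body below is an approximation, not the actual implementation):

import math
from typing import Optional

def batch_size_sketch(
        tokens_per_batch: int,
        seq_len: int,
        max_size: Optional[int] = None,
        divisor: int = 1,
        rounding: str = 'floor',
) -> int:
    # Raw batch size from the token budget, e.g. 16384 / 576 = 28.44
    raw = tokens_per_batch / seq_len
    round_fn = {'floor': math.floor, 'ceil': math.ceil, 'round': round}[rounding]
    bs = round_fn(raw)
    bs = max((bs // divisor) * divisor, divisor)  # snap down to a multiple of divisor
    if max_size is not None:
        bs = min(bs, max_size)
    return bs

print(batch_size_sketch(4096 * 4, 576, divisor=8))  # floor(28.44) = 28 -> 24
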
@@ -64,14 +76,20 @@ class NaFlexCollator:
 
     def __init__(
             self,
-            max_seq_len=None,
-    ):
-        self.max_seq_len = max_seq_len or 576  # Default ViT-B/16 sequence length (577 = 24*24)
+            max_seq_len: Optional[int] = None,
+    ) -> None:
+        """Initialize NaFlexCollator.
 
-    def __call__(self, batch):
+        Args:
+            max_seq_len: Maximum sequence length for batching.
         """
+        self.max_seq_len = max_seq_len or 576  # Default ViT-B/16 sequence length (577 = 24*24)
+
+    def __call__(self, batch: List[Tuple[Dict[str, torch.Tensor], Union[int, torch.Tensor]]]) -> Tuple[Dict[str, torch.Tensor], torch.Tensor]:
+        """Collate batch of NaFlex samples.
+
         Args:
-            batch: List of tuples (patch_dict, target)
+            batch: List of tuples (patch_dict, target).
 
         Returns:
             A tuple of (input_dict, targets) where input_dict contains:
@@ -99,11 +117,20 @@ def __call__(self, batch):
         # Find the maximum number of patches in this batch
         max_patches = max(item['patches'].shape[0] for item in patch_dicts)
 
-        # Get patch dimensionality
-        patch_dim = patch_dicts[0]['patches'].shape[1]
+        # Check if patches are flattened or unflattened
+        patches_tensor = patch_dicts[0]['patches']
+        is_unflattened = patches_tensor.ndim == 4  # [N, Ph, Pw, C]
+
+        if is_unflattened:
+            # Patches are [N, Ph, Pw, C] - variable patch size mode
+            _, ph, pw, c = patches_tensor.shape
+            patches = torch.zeros((batch_size, max_patches, ph, pw, c), dtype=torch.float32)
+        else:
+            # Patches are [N, P*P*C] - normal mode
+            patch_dim = patches_tensor.shape[1]
+            patches = torch.zeros((batch_size, max_patches, patch_dim), dtype=torch.float32)
 
-        # Prepare tensors for the batch
-        patches = torch.zeros((batch_size, max_patches, patch_dim), dtype=torch.float32)
+        # Prepare other tensors
         patch_coord = torch.zeros((batch_size, max_patches, 2), dtype=torch.int64)  # [B, N, 2] for (y, x)
         patch_valid = torch.zeros((batch_size, max_patches), dtype=torch.bool)
 
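The ndim check above is what lets the collator pad both layouts. The two patch layouts it distinguishes, with hypothetical shapes:

import torch

flat = torch.randn(100, 16 * 16 * 3)   # flattened mode: [N, P*P*C], ndim == 2
unflat = torch.randn(100, 16, 16, 3)   # variable patch size mode: [N, Ph, Pw, C], ndim == 4

# Matching zero-padded batch buffers for batch_size=8, max_patches=128:
padded_flat = torch.zeros((8, 128, 16 * 16 * 3), dtype=torch.float32)   # [B, N, P*P*C]
padded_unflat = torch.zeros((8, 128, 16, 16, 3), dtype=torch.float32)   # [B, N, Ph, Pw, C]
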
@@ -115,12 +142,58 @@ def __call__(self, batch):
             patch_coord[i, :num_patches] = patch_dict['patch_coord'][:num_patches]
             patch_valid[i, :num_patches] = patch_dict['patch_valid'][:num_patches]
 
-        return {
+        result = {
             'patches': patches,
             'patch_coord': patch_coord,
             'patch_valid': patch_valid,
             'seq_len': max_patches,
-        }, targets
+        }
+
+        return result, targets
+
+
+def _resolve_patch_cfg(
+        patch_size: Optional[Union[int, Tuple[int, int]]],
+        patch_size_choices: Optional[List[int]],
+        patch_size_choice_probs: Optional[List[float]],
+) -> Tuple[List[Tuple[int, int]], List[float], bool]:
+    """Resolve patch size configuration.
+
+    Args:
+        patch_size: Single patch size to use.
+        patch_size_choices: List of patch sizes to choose from.
+        patch_size_choice_probs: Probabilities for each patch size choice.
+
+    Returns:
+        Tuple of (sizes, probs, variable) where sizes is list of patch size tuples,
+        probs is list of probabilities, and variable indicates if patch size varies.
+    """
+    # If both are None, default to patch_size=16
+    if patch_size is None and patch_size_choices is None:
+        patch_size = 16
+
+    if (patch_size is None) == (patch_size_choices is None):
+        raise ValueError(
+            "Specify exactly one of `patch_size` or `patch_size_choices`."
+        )
+
+    if patch_size is not None:
+        sizes = [to_2tuple(patch_size)]
+        probs = [1.0]
+        variable = False
+    else:
+        sizes = [to_2tuple(p) for p in patch_size_choices]
+        if patch_size_choice_probs is None:
+            probs = [1.0 / len(sizes)] * len(sizes)
+        else:
+            if len(patch_size_choice_probs) != len(sizes):
+                raise ValueError("`patch_size_choice_probs` length mismatch.")
+            s = float(sum(patch_size_choice_probs))
+            if s <= 0:
+                raise ValueError("`patch_size_choice_probs` sum to zero.")
+            probs = [p / s for p in patch_size_choice_probs]
+        variable = True
+    return sizes, probs, variable
 
 
 class NaFlexMapDatasetWrapper(IterableDataset):
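
The full body of _resolve_patch_cfg appears above, so its behavior can be exercised directly:

# Single fixed size (also the default when neither argument is given):
sizes, probs, variable = _resolve_patch_cfg(16, None, None)
# -> [(16, 16)], [1.0], False

# Multiple choices, uniform probabilities by default:
sizes, probs, variable = _resolve_patch_cfg(None, [12, 16, 24], None)
# -> [(12, 12), (16, 16), (24, 24)], [1/3, 1/3, 1/3], True

# Explicit weights are normalized to sum to 1:
sizes, probs, variable = _resolve_patch_cfg(None, [12, 16], [3.0, 1.0])
# -> [(12, 12), (16, 16)], [0.75, 0.25], True

# Passing both patch_size and patch_size_choices raises ValueError.
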
@@ -138,9 +211,11 @@ class NaFlexMapDatasetWrapper(IterableDataset):
     def __init__(
             self,
             base_dataset: Dataset,
-            patch_size: Union[int, Tuple[int, int]] = 16,
-            seq_lens: List[int] = (128, 256, 576, 784, 1024),
-            max_tokens_per_batch: int = 4096 * 4,  # Example: 16k tokens
+            patch_size: Optional[Union[int, Tuple[int, int]]] = None,
+            patch_size_choices: Optional[List[int]] = None,
+            patch_size_choice_probs: Optional[List[float]] = None,
+            seq_lens: Tuple[int, ...] = (128, 256, 576, 784, 1024),
+            max_tokens_per_batch: int = 4096 * 4,
             transform_factory: Optional[Callable] = None,
             mixup_fn: Optional[Callable] = None,
             seed: int = 42,
@@ -149,14 +224,32 @@ def __init__(
             rank: int = 0,
             world_size: int = 1,
             epoch: int = 0,
-            batch_divisor: int = 8,  # Ensure batch size is divisible by this
-    ):
+            batch_divisor: int = 8,
+    ) -> None:
+        """Initialize NaFlexMapDatasetWrapper.
+
+        Args:
+            base_dataset: Map-style dataset to wrap.
+            patch_size: Single patch size to use.
+            patch_size_choices: List of patch sizes to randomly select from.
+            patch_size_choice_probs: Probabilities for each patch size.
+            seq_lens: Sequence lengths to use for batching.
+            max_tokens_per_batch: Target tokens per batch.
+            transform_factory: Factory function for creating transforms.
+            mixup_fn: Optional mixup function.
+            seed: Random seed.
+            shuffle: Whether to shuffle data.
+            distributed: Whether using distributed training.
+            rank: Process rank for distributed training.
+            world_size: Total number of processes.
+            epoch: Starting epoch.
+            batch_divisor: Ensure batch size is divisible by this.
+        """
         super().__init__()
         if not hasattr(base_dataset, '__len__') or not hasattr(base_dataset, '__getitem__'):
             raise TypeError("base_dataset must be a map-style dataset (implement __len__ and __getitem__)")
 
         self.base_dataset = base_dataset
-        self.patch_size = patch_size if isinstance(patch_size, tuple) else (patch_size, patch_size)
         self.seq_lens = sorted(list(set(seq_lens)))  # Ensure unique and sorted
         self.max_tokens_per_batch = max_tokens_per_batch
         self.seed = seed
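
Putting the new arguments together, a hypothetical construction of the wrapper in variable-patch-size mode (my_imagenet and my_transform_factory are illustrative placeholders, not names from this commit):

wrapper = NaFlexMapDatasetWrapper(
    base_dataset=my_imagenet,                   # placeholder map-style dataset
    patch_size_choices=[12, 16, 24],            # one size drawn per batch
    patch_size_choice_probs=[0.25, 0.5, 0.25],
    seq_lens=(128, 256, 576, 784, 1024),
    max_tokens_per_batch=4096 * 4,
    transform_factory=my_transform_factory,     # placeholder; called per (seq_len, patch_size)
    batch_divisor=8,
)
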
@@ -167,17 +260,37 @@ def __init__(
         self.epoch = epoch
         self.batch_divisor = batch_divisor
 
-        # Pre-initialize transforms and collate fns for each sequence length
-        self.transforms: Dict[int, Optional[Callable]] = {}
+        # Resolve patch size configuration
+        self.patch_sizes, self.patch_size_probs, self.variable_patch_size = _resolve_patch_cfg(
+            patch_size,
+            patch_size_choices,
+            patch_size_choice_probs
+        )
+
+        # Pre-initialize transforms and collate fns for each (seq_len, patch_idx) combination
+        self.transforms: Dict[Tuple[int, int], Optional[Callable]] = {}
         self.collate_fns: Dict[int, Callable] = {}
+        self.patchifiers: List[Callable] = []
+
         for seq_len in self.seq_lens:
-            if transform_factory:
-                self.transforms[seq_len] = transform_factory(max_seq_len=seq_len, patch_size=self.patch_size)
-            else:
-                self.transforms[seq_len] = None  # No transform
             self.collate_fns[seq_len] = NaFlexCollator(seq_len)
+
+        for patch_idx, patch_size_tuple in enumerate(self.patch_sizes):
+            # Pre-initialize patchifiers for each patch size (indexed by patch_idx)
+            self.patchifiers.append(Patchify(
+                patch_size=patch_size_tuple,
+                flatten_patches=not self.variable_patch_size
+            ))
+
+            # Create transforms for each (seq_len, patch_idx) combination
+            for seq_len in self.seq_lens:
+                key = (seq_len, patch_idx)
+                if transform_factory:
+                    self.transforms[key] = transform_factory(max_seq_len=seq_len, patch_size=patch_size_tuple)
+                else:
+                    self.transforms[key] = None  # No transform
+
         self.mixup_fn = mixup_fn
-        self.patchifier = Patchify(self.patch_size)
 
         # --- Canonical Schedule Calculation (Done Once) ---
         self._canonical_batch_schedule: List[Tuple[int, int]] = []
@@ -363,25 +476,30 @@ def _prepare_epoch_batches(self, epoch: int):
                 f"Indices remaining: {effective_samples_this_rank - scheduled_samples_count}."
             )
 
-    def set_epoch(self, epoch: int):
-        """Updates the epoch, regenerating the epoch-specific batches."""
+    def set_epoch(self, epoch: int) -> None:
+        """Updates the epoch, regenerating the epoch-specific batches.
+
+        Args:
+            epoch: New epoch number.
+        """
         # Only regenerate if the epoch actually changes
         if epoch != self.epoch:
             self.epoch = epoch
             self._prepare_epoch_batches(epoch)
 
     def __len__(self) -> int:
-        """
-        Returns the number of batches per **worker** for the current epoch.
-        Calculated based on the fixed number of batches per rank divided by
-        the number of workers.
+        """Returns the number of batches per worker for the current epoch.
+
+        Returns:
+            Number of batches this worker will process.
         """
         return self._num_batches_per_rank
 
     def __iter__(self) -> Iterator[Tuple[Dict[str, torch.Tensor], torch.Tensor]]:
-        """
-        Iterates through the pre-calculated batches for the current epoch,
-        distributing them among workers.
+        """Iterates through pre-calculated batches for the current epoch.
+
+        Yields:
+            Tuple of (input_dict, targets) for each batch.
         """
         worker_info = torch.utils.data.get_worker_info()
         num_workers = worker_info.num_workers if worker_info else 1
@@ -394,10 +512,17 @@ def __iter__(self) -> Iterator[Tuple[Dict[str, torch.Tensor], torch.Tensor]]:
             if not indices:  # Skip if a batch ended up with no indices (shouldn't happen often)
                 continue
 
-            # Get the pre-initialized transform for this sequence length
-            transform = self.transforms.get(seq_len)
+            # Select patch size for this batch
+            patch_idx = 0
+            if self.variable_patch_size:
+                # Use torch multinomial for weighted random choice
+                patch_idx = torch.multinomial(torch.tensor(self.patch_size_probs), 1).item()
+
+            # Get the pre-initialized transform and patchifier using patch_idx
+            transform_key = (seq_len, patch_idx)
+            transform = self.transforms.get(transform_key)
+            batch_patchifier = self.patchifiers[patch_idx]
 
-            batch_samples = []
             batch_imgs = []
             batch_targets = []
             for idx in indices:
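
The weighted selection above is a single multinomial draw over the normalized probabilities produced by _resolve_patch_cfg; in isolation:

import torch

patch_size_probs = [0.25, 0.5, 0.25]  # example values in _resolve_patch_cfg's output format
patch_idx = torch.multinomial(torch.tensor(patch_size_probs), 1).item()
# patch_idx indexes self.patch_sizes / self.patchifiers; over many batches,
# index 1 would be drawn roughly 50% of the time here.
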
@@ -426,7 +551,7 @@ def __iter__(self) -> Iterator[Tuple[Dict[str, torch.Tensor], torch.Tensor]]:
             if self.mixup_fn is not None:
                 batch_imgs, batch_targets = self.mixup_fn(batch_imgs, batch_targets)
 
-            batch_imgs = [self.patchifier(img) for img in batch_imgs]
+            batch_imgs = [batch_patchifier(img) for img in batch_imgs]
             batch_samples = list(zip(batch_imgs, batch_targets))
             if batch_samples:  # Only yield if we successfully processed samples
                 # Collate the processed samples into a batch
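
End to end, an epoch over the wrapper then reduces to roughly the loop below (a sketch; num_epochs and the training step are placeholders):

for epoch in range(num_epochs):
    wrapper.set_epoch(epoch)  # regenerates the shuffled batch schedule
    for input_dict, targets in wrapper:
        # input_dict holds 'patches', 'patch_coord', 'patch_valid', 'seq_len';
        # all samples in a batch share a single patch size and sequence length
        ...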
