
Commit 4ff865c

A bit of docstring and comment consistency cleanup, remove some debug code
1 parent dac2ec6 commit 4ff865c

6 files changed: +44 -18 lines changed

timm/data/naflex_dataset.py

Lines changed: 10 additions & 12 deletions
@@ -1,15 +1,16 @@
-"""
-Dynamic Sequence Length Datasets for Variable Resolution Image Processing
+""" Dynamic Sequence Length Datasets for Variable Resolution Image Processing
 
 Implements two dataset wrappers:
-1. DynamicSeqMapDataset - Map-style dataset that returns batches with variable sequence lengths
-2. DynamicSeqIterDataset - Iterable dataset that yields batches with variable sequence lengths
+1. NaFlexMapDatasetWrapper - Map-style dataset that returns batches with variable sequence lengths
+TODO: 2. NaFlexIterableDatasetWrapper - Iterable dataset that yields batches with variable sequence lengths
 
 Both support:
 - Pre-initialized transforms for efficiency
 - Distributed training
 - Multiple workers
 - Variable batch sizes based on sequence length
+
+Hacked together by / Copyright 2025, Ross Wightman, Hugging Face
 """
 
 import math
@@ -20,12 +21,10 @@
 
 import torch
 from torch.utils.data import Dataset, IterableDataset, DataLoader
-from torchvision import transforms
 from PIL import Image
 
-
-from .naflex_transforms import Patchify, patchify_image
-from ..layers import to_2tuple
+from .naflex_transforms import Patchify
+from timm.layers import to_2tuple
 
 
 def calculate_naflex_batch_size(
@@ -203,7 +202,7 @@ class NaFlexMapDatasetWrapper(IterableDataset):
     Yields batches with variable sequence lengths. It calculates a canonical
     batch schedule (sequence length, batch size pairs) once based on the
     total dataset size (padded for distribution). Each epoch, it shuffles
-    the *order* of this canonical schedule and the dataset indices.
+    the order of this canonical schedule and the dataset indices.
     This ensures a consistent number of batches and samples per epoch
     across all ranks. Handles distributed training and multiple workers.
     """
@@ -292,13 +291,13 @@ def __init__(
 
         self.mixup_fn = mixup_fn
 
-        # --- Canonical Schedule Calculation (Done Once) ---
+        # Canonical Schedule Calculation (Done Once)
         self._canonical_batch_schedule: List[Tuple[int, int]] = []
         self._num_batches_per_rank: int = 0
         self._padded_samples_per_rank: int = 0
         self._create_canonical_schedule()  # Calculate schedule based on padded size
 
-        # --- Per-Epoch State ---
+        # Per-Epoch State
         # Stores (seq_len, list_of_indices) for the current epoch, specific to this rank
         self._epoch_batches: List[Tuple[int, List[int]]] = []
         self._prepare_epoch_batches(self.epoch)  # setup for initial epoch
@@ -420,7 +419,6 @@ def _prepare_epoch_batches(self, epoch: int):
         if len(indices_for_ranks) != padded_total_len:
             raise RuntimeError(f"Internal Error: Padded index list length {len(indices_for_ranks)} does not match expected {padded_total_len}")
 
-
         # 3. Select indices for the current rank
         if self.distributed and self.world_size > 1:
             indices_this_rank = indices_for_ranks[self.rank::self.world_size]
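
The docstring touched above describes the wrapper's core trick: a canonical (seq_len, batch_size) schedule computed once from the padded dataset size, with only the schedule order and the sample indices reshuffled each epoch, so every rank sees the same number of batches and samples. A minimal self-contained sketch of that idea follows; function and parameter names are illustrative, not timm's internals.

import random
from typing import List, Sequence, Tuple

def make_canonical_schedule(
        padded_samples_per_rank: int,
        seq_lens: Sequence[int] = (128, 256, 576, 1024),
        max_tokens_per_batch: int = 4096,
) -> List[Tuple[int, int]]:
    # Build a fixed list of (seq_len, batch_size) pairs covering every
    # padded sample exactly once; batch size shrinks as sequence length
    # grows so the token count per batch stays roughly constant.
    schedule = []
    remaining = padded_samples_per_rank
    i = 0
    while remaining > 0:
        seq_len = seq_lens[i % len(seq_lens)]
        batch_size = min(max(1, max_tokens_per_batch // seq_len), remaining)
        schedule.append((seq_len, batch_size))
        remaining -= batch_size
        i += 1
    return schedule

# Per epoch: shuffle only the *order* of the canonical schedule, seeding
# with the epoch so all ranks agree on batch count and sizes.
schedule = make_canonical_schedule(10_000)
random.Random(0).shuffle(schedule)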

timm/data/naflex_loader.py

Lines changed: 11 additions & 0 deletions
@@ -1,3 +1,14 @@
+"""NaFlex data loader for dynamic sequence length training.
+
+This module provides a specialized data loader for Vision Transformer models that supports:
+- Dynamic sequence length sampling during training for improved efficiency
+- Variable patch size training with probabilistic selection
+- Patch-level random erasing augmentation
+- Efficient GPU prefetching with normalization
+
+Hacked together by / Copyright 2025, Ross Wightman, Hugging Face
+"""
+
 import math
 from contextlib import suppress
 from functools import partial
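
Of the features the new docstring lists, "variable patch size training with probabilistic selection" is the least self-explanatory. One simple way to realize it is to draw a patch size per batch from a weighted candidate list; the sizes and weights below are made-up examples, not timm's defaults.

import random

# Hypothetical candidate patch sizes and their selection probabilities.
PATCH_SIZES = (12, 16, 24)
PATCH_SIZE_WEIGHTS = (0.25, 0.5, 0.25)

def sample_patch_size(rng: random.Random) -> int:
    # One draw per batch; the model must support resizing its patch
    # embedding (FlexiViT-style) for this to be usable.
    return rng.choices(PATCH_SIZES, weights=PATCH_SIZE_WEIGHTS, k=1)[0]

rng = random.Random(42)
print([sample_patch_size(rng) for _ in range(8)])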

timm/data/naflex_mixup.py

Lines changed: 1 addition & 2 deletions
@@ -11,6 +11,7 @@
 all augmentation hyper‑parameters in one place, making it easy to plug into
 different dataset wrappers.
 
+Hacked together by / Copyright 2025, Ross Wightman, Hugging Face
 """
 import math
 import random
@@ -113,7 +114,6 @@ def mix_batch_variable_size(
 
             corrected_lam = 1.0 - cut_area / float(dest_area)
             lam_list[i] = corrected_lam
-            #print(i, 'Doing cutmix', yl_i, xl_i, yl_j, xl_j, ch, cw, lam_raw, corrected_lam)
         else:
             # Mixup: blend the entire overlap region
             patch_i = xi[:, top_i:top_i + oh, left_i:left_i + ow]
@@ -125,7 +125,6 @@
 
             corrected_lam = (dest_area - overlap_area) / dest_area + lam_raw * overlap_area / dest_area
             lam_list[i] = corrected_lam
-            #print(i, 'Doing mixup', top_i, left_i, top_j, left_j, (oh, ow), (hi, wi), (hj, wj), lam_raw, corrected_lam)
 
     return mixed_imgs, lam_list, pair_to
 
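
The deleted #print lines sat next to the lambda corrections, which deserve a worked example: because variable-size images rarely admit the exact requested cut or a full-image blend, the effective mixing ratio is recomputed from the actual areas. The numbers below are arbitrary illustrations of the two formulas visible in the diff.

# CutMix branch: lam reflects how much of the destination survives the cut.
dest_area = 224 * 224            # 50176
cut_area = 96 * 96               # 9216
corrected_lam = 1.0 - cut_area / float(dest_area)   # ~0.816

# Mixup branch: only the overlap region is blended, so the effective lam
# interpolates between 1.0 (no overlap) and lam_raw (full overlap).
lam_raw = 0.3
overlap_area = 160 * 160         # 25600
corrected_lam = (dest_area - overlap_area) / dest_area + lam_raw * overlap_area / dest_area
# = 0.4898 + 0.3 * 0.5102 ~ 0.643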

timm/data/naflex_random_erasing.py

Lines changed: 13 additions & 0 deletions
@@ -1,3 +1,16 @@
+"""Patch-level random erasing augmentation for NaFlex Vision Transformers.
+
+This module implements random erasing specifically designed for patchified images,
+operating at the patch granularity rather than pixel level. It supports two modes:
+- 'patch': Randomly erases individual patches (speckle-like noise)
+- 'region': Erases contiguous rectangular regions of patches (similar to original RandomErasing)
+
+The implementation is coordinate-aware, respecting valid patch boundaries and supporting
+variable patch sizes in NaFlex training.
+
+Hacked together by / Copyright 2025, Ross Wightman, Hugging Face
+"""
+
 import random
 import math
 from typing import Optional, Union, Tuple
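
The 'patch' mode the new docstring describes operates on already-patchified inputs. A minimal sketch of that mode, assuming a (num_patches, patch_dim) tensor and a validity mask for non-padding patches; names and API here are illustrative, not timm's.

import torch

def erase_random_patches(
        patches: torch.Tensor,      # (num_patches, patch_dim), patchified image
        valid: torch.Tensor,        # (num_patches,) bool mask of non-padding patches
        erase_prob: float = 0.5,
        erase_frac: float = 0.1,
) -> torch.Tensor:
    # Independently replace a random subset of valid patches with noise,
    # producing speckle-like corruption at patch granularity.
    if torch.rand(()) > erase_prob:
        return patches
    valid_idx = valid.nonzero(as_tuple=True)[0]
    n_erase = max(1, int(erase_frac * valid_idx.numel()))
    chosen = valid_idx[torch.randperm(valid_idx.numel())[:n_erase]]
    patches[chosen] = torch.randn_like(patches[chosen])  # noise fill
    return patches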

timm/data/naflex_transforms.py

Lines changed: 2 additions & 0 deletions
@@ -5,6 +5,8 @@
 - FlexiViT: https://arxiv.org/abs/2212.08013
 
 Enables variable resolution/aspect ratio image handling with efficient patching.
+
+Hacked together by / Copyright 2025, Ross Wightman, Hugging Face
 """
 
 import math
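
For readers new to this module, the operation everything builds on is patchifying an image of arbitrary resolution and aspect ratio into flattened patch tokens plus their grid coordinates. A rough sketch of that idea, not the actual Patchify transform:

import torch

def simple_patchify(img: torch.Tensor, patch_size: int = 16):
    # Crop to a multiple of patch_size, then return flattened patches
    # plus their (row, col) grid coordinates.
    c, h, w = img.shape
    gh, gw = h // patch_size, w // patch_size
    img = img[:, :gh * patch_size, :gw * patch_size]
    # (C, gh, P, gw, P) -> (gh*gw, P*P*C)
    patches = (
        img.reshape(c, gh, patch_size, gw, patch_size)
           .permute(1, 3, 2, 4, 0)
           .reshape(gh * gw, patch_size * patch_size * c)
    )
    coords = torch.stack(torch.meshgrid(
        torch.arange(gh), torch.arange(gw), indexing='ij'), dim=-1).reshape(-1, 2)
    return patches, coords

patches, coords = simple_patchify(torch.randn(3, 224, 320))
assert patches.shape == (14 * 20, 16 * 16 * 3)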

timm/models/naflexvit.py

Lines changed: 7 additions & 4 deletions
@@ -1,16 +1,19 @@
-""" Vision Transformer (New)
+""" NaFlex Vision Transformer
 
 An improved version of the Vision Transformer with:
 1. Encapsulated embedding and position encoding in a single module
 2. Support for linear patch embedding on pre-patchified inputs
-3. Support for NaFlex functionality (NaViT + FlexiViT)
+3. Support for NaFlex variable aspect, variable resolution
+4. Support for FlexiViT variable patch size
+5. Support for NaViT fractional/factorized position embedding
 
-Based on:
+Based on ideas from:
 - Original Vision Transformer: https://arxiv.org/abs/2010.11929
 - FlexiViT: https://arxiv.org/abs/2212.08013
 - NaViT: https://arxiv.org/abs/2307.06304
+- NaFlex (SigLip-2): https://arxiv.org/abs/2502.14786
 
-Copyright 2025
+Hacked together by / Copyright 2025, Ross Wightman, Hugging Face
 """
 
 import logging
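
For context on how the retitled module is consumed, a hedged usage sketch: the model name below is hypothetical, so check timm's model registry for the actual naflexvit entries and supported kwargs.

import timm
import torch

# Model name is an assumption for illustration; verify against
# timm.list_models('naflexvit*') on your installed version.
model = timm.create_model('naflexvit_base_patch16', pretrained=False, num_classes=10)
model.eval()

with torch.no_grad():
    out = model(torch.randn(1, 3, 224, 224))  # standard NCHW input path
print(out.shape)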
