@@ -25,16 +25,28 @@
 
 
 from .naflex_transforms import Patchify, patchify_image
+from ..layers import to_2tuple
 
 
 def calculate_naflex_batch_size(
         tokens_per_batch: int,
         seq_len: int,
         max_size: Optional[int] = None,
         divisor: int = 1,
-        rounding: str = 'floor',
-):
-    """Calculate batch size based on sequence length with divisibility constraints."""
+        rounding: str = 'floor',
+) -> int:
+    """Calculate batch size based on sequence length with divisibility constraints.
+
+    Args:
+        tokens_per_batch: Target number of tokens per batch.
+        seq_len: Sequence length for this batch.
+        max_size: Optional maximum batch size.
+        divisor: Ensure batch size is divisible by this value.
+        rounding: Rounding method ('floor', 'ceil', 'round').
+
+    Returns:
+        Calculated batch size.
+    """
     # Calculate raw batch size based on sequence length
     raw_batch_size = tokens_per_batch / seq_len
 
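A rough sketch of the arithmetic this function performs, with illustrative numbers; the divisor snapping and max_size capping are assumed from the argument docs above rather than shown in the visible hunk:

import math

tokens_per_batch, seq_len, divisor = 4096 * 4, 576, 8
raw = tokens_per_batch / seq_len          # 16384 / 576 ~= 28.4 samples
batch = math.floor(raw)                   # 28 with rounding='floor'
batch = (batch // divisor) * divisor      # assumed snap to a multiple of divisor -> 24
print(batch)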
@@ -64,14 +76,20 @@ class NaFlexCollator:
 
     def __init__(
             self,
-            max_seq_len=None,
-    ):
-        self.max_seq_len = max_seq_len or 576  # Default ViT-B/16 sequence length (577 = 24*24)
+            max_seq_len: Optional[int] = None,
+    ) -> None:
+        """Initialize NaFlexCollator.
 
-    def __call__(self, batch):
+        Args:
+            max_seq_len: Maximum sequence length for batching.
         """
+        self.max_seq_len = max_seq_len or 576  # Default ViT-B/16 sequence length (577 = 24*24)
+
+    def __call__(self, batch: List[Tuple[Dict[str, torch.Tensor], Union[int, torch.Tensor]]]) -> Tuple[Dict[str, torch.Tensor], torch.Tensor]:
+        """Collate batch of NaFlex samples.
+
         Args:
-            batch: List of tuples (patch_dict, target)
+            batch: List of tuples (patch_dict, target).
 
         Returns:
             A tuple of (input_dict, targets) where input_dict contains:
@@ -99,11 +117,20 @@ def __call__(self, batch):
         # Find the maximum number of patches in this batch
         max_patches = max(item['patches'].shape[0] for item in patch_dicts)
 
-        # Get patch dimensionality
-        patch_dim = patch_dicts[0]['patches'].shape[1]
+        # Check if patches are flattened or unflattened
+        patches_tensor = patch_dicts[0]['patches']
+        is_unflattened = patches_tensor.ndim == 4  # [N, Ph, Pw, C]
+
+        if is_unflattened:
+            # Patches are [N, Ph, Pw, C] - variable patch size mode
+            _, ph, pw, c = patches_tensor.shape
+            patches = torch.zeros((batch_size, max_patches, ph, pw, c), dtype=torch.float32)
+        else:
+            # Patches are [N, P*P*C] - normal mode
+            patch_dim = patches_tensor.shape[1]
+            patches = torch.zeros((batch_size, max_patches, patch_dim), dtype=torch.float32)
 
-        # Prepare tensors for the batch
-        patches = torch.zeros((batch_size, max_patches, patch_dim), dtype=torch.float32)
+        # Prepare other tensors
         patch_coord = torch.zeros((batch_size, max_patches, 2), dtype=torch.int64)  # [B, N, 2] for (y, x)
         patch_valid = torch.zeros((batch_size, max_patches), dtype=torch.bool)
 
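For reference, a small sketch of the two padded layouts the branch above allocates; the batch and patch dimensions here are made-up values, only the shapes follow the code:

import torch

batch_size, max_patches = 4, 256
# Flattened mode: each sample carries [N, P*P*C] patches, e.g. 16*16*3 = 768.
patches_flat = torch.zeros((batch_size, max_patches, 768), dtype=torch.float32)
# Unflattened (variable patch size) mode: patches stay [N, Ph, Pw, C].
patches_4d = torch.zeros((batch_size, max_patches, 16, 16, 3), dtype=torch.float32)
print(patches_flat.shape)  # torch.Size([4, 256, 768])
print(patches_4d.shape)    # torch.Size([4, 256, 16, 16, 3])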
@@ -115,12 +142,58 @@ def __call__(self, batch):
             patch_coord[i, :num_patches] = patch_dict['patch_coord'][:num_patches]
             patch_valid[i, :num_patches] = patch_dict['patch_valid'][:num_patches]
 
-        return {
+        result = {
             'patches': patches,
             'patch_coord': patch_coord,
             'patch_valid': patch_valid,
             'seq_len': max_patches,
-        }, targets
+        }
+
+        return result, targets
+
+
+def _resolve_patch_cfg(
+        patch_size: Optional[Union[int, Tuple[int, int]]],
+        patch_size_choices: Optional[List[int]],
+        patch_size_choice_probs: Optional[List[float]],
+) -> Tuple[List[Tuple[int, int]], List[float], bool]:
+    """Resolve patch size configuration.
+
+    Args:
+        patch_size: Single patch size to use.
+        patch_size_choices: List of patch sizes to choose from.
+        patch_size_choice_probs: Probabilities for each patch size choice.
+
+    Returns:
+        Tuple of (sizes, probs, variable) where sizes is list of patch size tuples,
+        probs is list of probabilities, and variable indicates if patch size varies.
+    """
+    # If both are None, default to patch_size=16
+    if patch_size is None and patch_size_choices is None:
+        patch_size = 16
+
+    if (patch_size is None) == (patch_size_choices is None):
+        raise ValueError(
+            "Specify exactly one of `patch_size` or `patch_size_choices`."
+        )
+
+    if patch_size is not None:
+        sizes = [to_2tuple(patch_size)]
+        probs = [1.0]
+        variable = False
+    else:
+        sizes = [to_2tuple(p) for p in patch_size_choices]
+        if patch_size_choice_probs is None:
+            probs = [1.0 / len(sizes)] * len(sizes)
+        else:
+            if len(patch_size_choice_probs) != len(sizes):
+                raise ValueError("`patch_size_choice_probs` length mismatch.")
+            s = float(sum(patch_size_choice_probs))
+            if s <= 0:
+                raise ValueError("`patch_size_choice_probs` sum to zero.")
+            probs = [p / s for p in patch_size_choice_probs]
+        variable = True
+    return sizes, probs, variable
 
 
 class NaFlexMapDatasetWrapper(IterableDataset):
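Assuming the helper behaves as written in the hunk above, a short usage sketch (values are illustrative):

# Single fixed size: one entry, probability 1.0, variable patch size disabled.
sizes, probs, variable = _resolve_patch_cfg(16, None, None)
# sizes == [(16, 16)], probs == [1.0], variable is False

# Multiple candidates with unnormalized weights: weights get renormalized.
sizes, probs, variable = _resolve_patch_cfg(None, [12, 16, 24], [1, 2, 1])
# sizes == [(12, 12), (16, 16), (24, 24)], probs == [0.25, 0.5, 0.25], variable is True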
@@ -138,9 +211,11 @@ class NaFlexMapDatasetWrapper(IterableDataset):
     def __init__(
             self,
             base_dataset: Dataset,
-            patch_size: Union[int, Tuple[int, int]] = 16,
-            seq_lens: List[int] = (128, 256, 576, 784, 1024),
-            max_tokens_per_batch: int = 4096 * 4,  # Example: 16k tokens
+            patch_size: Optional[Union[int, Tuple[int, int]]] = None,
+            patch_size_choices: Optional[List[int]] = None,
+            patch_size_choice_probs: Optional[List[float]] = None,
+            seq_lens: Tuple[int, ...] = (128, 256, 576, 784, 1024),
+            max_tokens_per_batch: int = 4096 * 4,
             transform_factory: Optional[Callable] = None,
             mixup_fn: Optional[Callable] = None,
             seed: int = 42,
@@ -149,14 +224,32 @@ def __init__(
             rank: int = 0,
             world_size: int = 1,
             epoch: int = 0,
-            batch_divisor: int = 8,  # Ensure batch size is divisible by this
-    ):
+            batch_divisor: int = 8,
+    ) -> None:
+        """Initialize NaFlexMapDatasetWrapper.
+
+        Args:
+            base_dataset: Map-style dataset to wrap.
+            patch_size: Single patch size to use.
+            patch_size_choices: List of patch sizes to randomly select from.
+            patch_size_choice_probs: Probabilities for each patch size.
+            seq_lens: Sequence lengths to use for batching.
+            max_tokens_per_batch: Target tokens per batch.
+            transform_factory: Factory function for creating transforms.
+            mixup_fn: Optional mixup function.
+            seed: Random seed.
+            shuffle: Whether to shuffle data.
+            distributed: Whether using distributed training.
+            rank: Process rank for distributed training.
+            world_size: Total number of processes.
+            epoch: Starting epoch.
+            batch_divisor: Ensure batch size is divisible by this.
+        """
         super().__init__()
         if not hasattr(base_dataset, '__len__') or not hasattr(base_dataset, '__getitem__'):
             raise TypeError("base_dataset must be a map-style dataset (implement __len__ and __getitem__)")
 
         self.base_dataset = base_dataset
-        self.patch_size = patch_size if isinstance(patch_size, tuple) else (patch_size, patch_size)
         self.seq_lens = sorted(list(set(seq_lens)))  # Ensure unique and sorted
         self.max_tokens_per_batch = max_tokens_per_batch
         self.seed = seed
@@ -167,17 +260,37 @@ def __init__(
         self.epoch = epoch
         self.batch_divisor = batch_divisor
 
-        # Pre-initialize transforms and collate fns for each sequence length
-        self.transforms: Dict[int, Optional[Callable]] = {}
+        # Resolve patch size configuration
+        self.patch_sizes, self.patch_size_probs, self.variable_patch_size = _resolve_patch_cfg(
+            patch_size,
+            patch_size_choices,
+            patch_size_choice_probs
+        )
+
+        # Pre-initialize transforms and collate fns for each (seq_len, patch_idx) combination
+        self.transforms: Dict[Tuple[int, int], Optional[Callable]] = {}
         self.collate_fns: Dict[int, Callable] = {}
+        self.patchifiers: List[Callable] = []
+
         for seq_len in self.seq_lens:
-            if transform_factory:
-                self.transforms[seq_len] = transform_factory(max_seq_len=seq_len, patch_size=self.patch_size)
-            else:
-                self.transforms[seq_len] = None  # No transform
             self.collate_fns[seq_len] = NaFlexCollator(seq_len)
+
+        for patch_idx, patch_size_tuple in enumerate(self.patch_sizes):
+            # Pre-initialize patchifiers for each patch size (indexed by patch_idx)
+            self.patchifiers.append(Patchify(
+                patch_size=patch_size_tuple,
+                flatten_patches=not self.variable_patch_size
+            ))
+
+            # Create transforms for each (seq_len, patch_idx) combination
+            for seq_len in self.seq_lens:
+                key = (seq_len, patch_idx)
+                if transform_factory:
+                    self.transforms[key] = transform_factory(max_seq_len=seq_len, patch_size=patch_size_tuple)
+                else:
+                    self.transforms[key] = None  # No transform
+
         self.mixup_fn = mixup_fn
-        self.patchifier = Patchify(self.patch_size)
 
         # --- Canonical Schedule Calculation (Done Once) ---
         self._canonical_batch_schedule: List[Tuple[int, int]] = []
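To illustrate the new constructor arguments, a hedged sketch of enabling variable patch size; the dataset and transform factory names are placeholders, not from this patch:

wrapper = NaFlexMapDatasetWrapper(
    base_dataset=my_dataset,                    # placeholder map-style dataset
    patch_size_choices=[12, 16, 24],            # mutually exclusive with patch_size
    patch_size_choice_probs=[0.25, 0.5, 0.25],
    seq_lens=(128, 256, 576, 1024),
    max_tokens_per_batch=4096 * 4,
    transform_factory=my_transform_factory,     # placeholder, called per (seq_len, patch_size)
)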
@@ -363,25 +476,30 @@ def _prepare_epoch_batches(self, epoch: int):
                     f"Indices remaining: {effective_samples_this_rank - scheduled_samples_count}."
                 )
 
-    def set_epoch(self, epoch: int):
-        """Updates the epoch, regenerating the epoch-specific batches."""
+    def set_epoch(self, epoch: int) -> None:
+        """Updates the epoch, regenerating the epoch-specific batches.
+
+        Args:
+            epoch: New epoch number.
+        """
         # Only regenerate if the epoch actually changes
         if epoch != self.epoch:
             self.epoch = epoch
             self._prepare_epoch_batches(epoch)
 
     def __len__(self) -> int:
-        """
-        Returns the number of batches per **worker** for the current epoch.
-        Calculated based on the fixed number of batches per rank divided by
-        the number of workers.
+        """Returns the number of batches per worker for the current epoch.
+
+        Returns:
+            Number of batches this worker will process.
         """
         return self._num_batches_per_rank
 
     def __iter__(self) -> Iterator[Tuple[Dict[str, torch.Tensor], torch.Tensor]]:
-        """
-        Iterates through the pre-calculated batches for the current epoch,
-        distributing them among workers.
+        """Iterates through pre-calculated batches for the current epoch.
+
+        Yields:
+            Tuple of (input_dict, targets) for each batch.
         """
         worker_info = torch.utils.data.get_worker_info()
         num_workers = worker_info.num_workers if worker_info else 1
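Because the wrapper is an IterableDataset that yields fully collated batches and splits them across workers in __iter__, my assumption is that it is driven roughly like this, with batch_size=None so the DataLoader does not re-batch:

import torch

loader = torch.utils.data.DataLoader(wrapper, batch_size=None, num_workers=4)
for epoch in range(2):
    wrapper.set_epoch(epoch)  # regenerate the per-epoch batch schedule before iterating
    for input_dict, targets in loader:
        pass  # input_dict holds 'patches', 'patch_coord', 'patch_valid', 'seq_len'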
@@ -394,10 +512,17 @@ def __iter__(self) -> Iterator[Tuple[Dict[str, torch.Tensor], torch.Tensor]]:
             if not indices:  # Skip if a batch ended up with no indices (shouldn't happen often)
                 continue
 
-            # Get the pre-initialized transform for this sequence length
-            transform = self.transforms.get(seq_len)
+            # Select patch size for this batch
+            patch_idx = 0
+            if self.variable_patch_size:
+                # Use torch multinomial for weighted random choice
+                patch_idx = torch.multinomial(torch.tensor(self.patch_size_probs), 1).item()
+
+            # Get the pre-initialized transform and patchifier using patch_idx
+            transform_key = (seq_len, patch_idx)
+            transform = self.transforms.get(transform_key)
+            batch_patchifier = self.patchifiers[patch_idx]
 
-            batch_samples = []
             batch_imgs = []
             batch_targets = []
             for idx in indices:
@@ -426,7 +551,7 @@ def __iter__(self) -> Iterator[Tuple[Dict[str, torch.Tensor], torch.Tensor]]:
             if self.mixup_fn is not None:
                 batch_imgs, batch_targets = self.mixup_fn(batch_imgs, batch_targets)
 
-            batch_imgs = [self.patchifier(img) for img in batch_imgs]
+            batch_imgs = [batch_patchifier(img) for img in batch_imgs]
             batch_samples = list(zip(batch_imgs, batch_targets))
             if batch_samples:  # Only yield if we successfully processed samples
                 # Collate the processed samples into a batch
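The per-batch patch-size draw earlier in __iter__ uses torch.multinomial over the normalized probabilities; a tiny standalone illustration with made-up weights:

import torch

probs = torch.tensor([0.25, 0.5, 0.25])
counts = torch.zeros(3)
for _ in range(1000):
    idx = torch.multinomial(probs, 1).item()  # weighted draw of one patch-size index
    counts[idx] += 1
print(counts / counts.sum())  # roughly tensor([0.25, 0.50, 0.25])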