@@ -96,13 +96,13 @@ class Mixup:
         cutmix_minmax (List[float]): cutmix min/max image ratio, cutmix is active and uses this vs alpha if not None.
         prob (float): probability of applying mixup or cutmix per batch or element
         switch_prob (float): probability of switching to cutmix instead of mixup when both are active
-        elementwise (bool): apply mixup/cutmix params per batch element instead of per batch
+        mode (str): how to apply mixup/cutmix params (per 'batch', 'pair' (pair of elements), 'elem' (element))
         correct_lam (bool): apply lambda correction when cutmix bbox clipped by image borders
         label_smoothing (float): apply label smoothing to the mixed target tensor
         num_classes (int): number of classes for target
     """
     def __init__(self, mixup_alpha=1., cutmix_alpha=0., cutmix_minmax=None, prob=1.0, switch_prob=0.5,
-                 elementwise=False, correct_lam=True, label_smoothing=0.1, num_classes=1000):
+                 mode='batch', correct_lam=True, label_smoothing=0.1, num_classes=1000):
         self.mixup_alpha = mixup_alpha
         self.cutmix_alpha = cutmix_alpha
         self.cutmix_minmax = cutmix_minmax
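
To illustrate the new constructor argument, a minimal usage sketch follows; the alpha values, batch shape, and class count are made up for the example, and the tensors are placed on GPU because mixup_target in this file defaults to device='cuda'. This is not part of the commit itself.

    import torch
    from timm.data.mixup import Mixup

    # illustrative values only; any even batch size works
    mixup_fn = Mixup(mixup_alpha=0.8, cutmix_alpha=1.0, mode='pair', label_smoothing=0.1, num_classes=1000)
    x = torch.rand(8, 3, 224, 224).cuda()
    y = torch.randint(0, 1000, (8,)).cuda()
    x, y_soft = mixup_fn(x, y)  # y_soft is an (8, 1000) soft-label tensor
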
@@ -114,7 +114,7 @@ def __init__(self, mixup_alpha=1., cutmix_alpha=0., cutmix_minmax=None, prob=1.0
         self.switch_prob = switch_prob
         self.label_smoothing = label_smoothing
         self.num_classes = num_classes
-        self.elementwise = elementwise
+        self.mode = mode
         self.correct_lam = correct_lam  # correct lambda based on clipped area for cutmix
         self.mixup_enabled = True  # set to false to disable mixing (intended to be set by train loop)
 
@@ -173,6 +173,26 @@ def _mix_elem(self, x):
                     x[i] = x[i] * lam + x_orig[j] * (1 - lam)
         return torch.tensor(lam_batch, device=x.device, dtype=x.dtype).unsqueeze(1)
 
+    def _mix_pair(self, x):
+        batch_size = len(x)
+        lam_batch, use_cutmix = self._params_per_elem(batch_size // 2)
+        x_orig = x.clone()  # need to keep an unmodified original for mixing source
+        for i in range(batch_size // 2):
+            j = batch_size - i - 1
+            lam = lam_batch[i]
+            if lam != 1.:
+                if use_cutmix[i]:
+                    (yl, yh, xl, xh), lam = cutmix_bbox_and_lam(
+                        x[i].shape, lam, ratio_minmax=self.cutmix_minmax, correct_lam=self.correct_lam)
+                    x[i][:, yl:yh, xl:xh] = x_orig[j][:, yl:yh, xl:xh]
+                    x[j][:, yl:yh, xl:xh] = x_orig[i][:, yl:yh, xl:xh]
+                    lam_batch[i] = lam
+                else:
+                    x[i] = x[i] * lam + x_orig[j] * (1 - lam)
+                    x[j] = x[j] * lam + x_orig[i] * (1 - lam)
+        lam_batch = np.concatenate((lam_batch, lam_batch[::-1]))
+        return torch.tensor(lam_batch, device=x.device, dtype=x.dtype).unsqueeze(1)
+
     def _mix_batch(self, x):
         lam, use_cutmix = self._params_per_batch()
         if lam == 1.:
@@ -188,7 +208,12 @@ def _mix_batch(self, x):
 
     def __call__(self, x, target):
         assert len(x) % 2 == 0, 'Batch size should be even when using this'
-        lam = self._mix_elem(x) if self.elementwise else self._mix_batch(x)
+        if self.mode == 'elem':
+            lam = self._mix_elem(x)
+        elif self.mode == 'pair':
+            lam = self._mix_pair(x)
+        else:
+            lam = self._mix_batch(x)
         target = mixup_target(target, self.num_classes, lam, self.label_smoothing)
         return x, target
 
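
The per-sample modes work because _mix_elem and _mix_pair return lam as an (N, 1) column tensor, so the blend inside mixup_target broadcasts row-wise over the flipped one-hot targets. A simplified, self-contained paraphrase of that blend is sketched below; label smoothing and the device argument of the real mixup_target are omitted, and blend_targets is a hypothetical name used only here.

    import torch
    import torch.nn.functional as F

    def blend_targets(target, num_classes, lam):
        # lam may be a scalar (mode='batch') or an (N, 1) tensor (mode='elem'/'pair')
        y1 = F.one_hot(target, num_classes).float()          # original batch order
        y2 = F.one_hot(target.flip(0), num_classes).float()  # reversed order, same pairing as the image mix
        return y1 * lam + y2 * (1. - lam)
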
@@ -199,25 +224,57 @@ class FastCollateMixup(Mixup):
     A Mixup impl that's performed while collating the batches.
     """
 
-    def _mix_elem_collate(self, output, batch):
+    def _mix_elem_collate(self, output, batch, half=False):
         batch_size = len(batch)
-        lam_batch, use_cutmix = self._params_per_elem(batch_size)
-        for i in range(batch_size):
+        num_elem = batch_size // 2 if half else batch_size
+        assert len(output) == num_elem
+        lam_batch, use_cutmix = self._params_per_elem(num_elem)
+        for i in range(num_elem):
             j = batch_size - i - 1
             lam = lam_batch[i]
             mixed = batch[i][0]
             if lam != 1.:
                 if use_cutmix[i]:
-                    mixed = mixed.copy()
+                    if not half:
+                        mixed = mixed.copy()
                     (yl, yh, xl, xh), lam = cutmix_bbox_and_lam(
                         output.shape, lam, ratio_minmax=self.cutmix_minmax, correct_lam=self.correct_lam)
                     mixed[:, yl:yh, xl:xh] = batch[j][0][:, yl:yh, xl:xh]
                     lam_batch[i] = lam
                 else:
                     mixed = mixed.astype(np.float32) * lam + batch[j][0].astype(np.float32) * (1 - lam)
-                    lam_batch[i] = lam
-                np.round(mixed, out=mixed)
+                    np.rint(mixed, out=mixed)
             output[i] += torch.from_numpy(mixed.astype(np.uint8))
+        if half:
+            lam_batch = np.concatenate((lam_batch, np.ones(num_elem)))
+        return torch.tensor(lam_batch).unsqueeze(1)
+
+    def _mix_pair_collate(self, output, batch):
+        batch_size = len(batch)
+        lam_batch, use_cutmix = self._params_per_elem(batch_size // 2)
+        for i in range(batch_size // 2):
+            j = batch_size - i - 1
+            lam = lam_batch[i]
+            mixed_i = batch[i][0]
+            mixed_j = batch[j][0]
+            assert 0 <= lam <= 1.0
+            if lam < 1.:
+                if use_cutmix[i]:
+                    (yl, yh, xl, xh), lam = cutmix_bbox_and_lam(
+                        output.shape, lam, ratio_minmax=self.cutmix_minmax, correct_lam=self.correct_lam)
+                    patch_i = mixed_i[:, yl:yh, xl:xh].copy()
+                    mixed_i[:, yl:yh, xl:xh] = mixed_j[:, yl:yh, xl:xh]
+                    mixed_j[:, yl:yh, xl:xh] = patch_i
+                    lam_batch[i] = lam
+                else:
+                    mixed_temp = mixed_i.astype(np.float32) * lam + mixed_j.astype(np.float32) * (1 - lam)
+                    mixed_j = mixed_j.astype(np.float32) * lam + mixed_i.astype(np.float32) * (1 - lam)
+                    mixed_i = mixed_temp
+                    np.rint(mixed_j, out=mixed_j)
+                    np.rint(mixed_i, out=mixed_i)
+            output[i] += torch.from_numpy(mixed_i.astype(np.uint8))
+            output[j] += torch.from_numpy(mixed_j.astype(np.uint8))
+        lam_batch = np.concatenate((lam_batch, lam_batch[::-1]))
         return torch.tensor(lam_batch).unsqueeze(1)
 
     def _mix_batch_collate(self, output, batch):
@@ -235,19 +292,25 @@ def _mix_batch_collate(self, output, batch):
                     mixed[:, yl:yh, xl:xh] = batch[j][0][:, yl:yh, xl:xh]
                 else:
                     mixed = mixed.astype(np.float32) * lam + batch[j][0].astype(np.float32) * (1 - lam)
-                np.round(mixed, out=mixed)
+                    np.rint(mixed, out=mixed)
             output[i] += torch.from_numpy(mixed.astype(np.uint8))
         return lam
 
     def __call__(self, batch, _=None):
         batch_size = len(batch)
         assert batch_size % 2 == 0, 'Batch size should be even when using this'
+        half = 'half' in self.mode
+        if half:
+            batch_size //= 2
         output = torch.zeros((batch_size, *batch[0][0].shape), dtype=torch.uint8)
-        if self.elementwise:
-            lam = self._mix_elem_collate(output, batch)
+        if self.mode == 'elem' or self.mode == 'half':
+            lam = self._mix_elem_collate(output, batch, half=half)
+        elif self.mode == 'pair':
+            lam = self._mix_pair_collate(output, batch)
         else:
             lam = self._mix_batch_collate(output, batch)
         target = torch.tensor([b[1] for b in batch], dtype=torch.int64)
         target = mixup_target(target, self.num_classes, lam, self.label_smoothing, device='cpu')
+        target = target[:batch_size]
         return output, target
 
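
For the collate-time version, a hedged sketch of the new modes in use; the shapes, class count, and random data below are illustrative, and each batch element is assumed to already be a (uint8 CHW numpy array, int label) pair as produced by timm's prefetcher-style pipeline.

    import numpy as np
    import torch
    from timm.data.mixup import FastCollateMixup

    collate_fn = FastCollateMixup(mixup_alpha=0.8, cutmix_alpha=1.0, mode='half', num_classes=10)
    batch = [(np.random.randint(0, 255, (3, 32, 32), dtype=np.uint8), i % 10) for i in range(8)]
    images, targets = collate_fn(batch)
    # With mode='half', only the first half of the batch is mixed and returned:
    # images is a (4, 3, 32, 32) uint8 tensor and targets a (4, 10) soft-label tensor.
    # mode='elem', 'pair', or 'batch' would instead return all 8 samples.

The same object can be passed as collate_fn to a torch.utils.data.DataLoader, which is how a training loop would normally wire it up.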