1- """ Mixup
2- Paper: `mixup: Beyond Empirical Risk Minimization` - https://arxiv.org/abs/1710.09412
1+ """ Mixup and Cutmix
2+
3+ Papers:
4+ mixup: Beyond Empirical Risk Minimization (https://arxiv.org/abs/1710.09412)
5+
6+ CutMix: Regularization Strategy to Train Strong Classifiers with Localizable Features (https://arxiv.org/abs/1905.04899)
7+
8+ Code Reference:
9+ CutMix: https://github.com/clovaai/CutMix-PyTorch
310
411Hacked together by / Copyright 2020 Ross Wightman
512"""
@@ -17,40 +24,230 @@ def mixup_target(target, num_classes, lam=1., smoothing=0.0, device='cuda'):
     on_value = 1. - smoothing + off_value
     y1 = one_hot(target, num_classes, on_value=on_value, off_value=off_value, device=device)
     y2 = one_hot(target.flip(0), num_classes, on_value=on_value, off_value=off_value, device=device)
-    return lam * y1 + (1. - lam) * y2
+    return y1 * lam + y2 * (1. - lam)
+
+
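Note that lam may be a plain float (batch mode) or, in elementwise mode, the (batch_size, 1) tensor returned by _mix_elem below; either broadcasts over the class dimension. A quick worked example with illustrative values:

    # num_classes=2, smoothing=0.1 -> off_value = 0.05, on_value = 0.95
    # targets [0, 1], lam = 0.6:
    #   y1 = [[0.95, 0.05], [0.05, 0.95]]   # smoothed one-hot of target
    #   y2 = [[0.05, 0.95], [0.95, 0.05]]   # smoothed one-hot of target.flip(0)
    #   y1 * 0.6 + y2 * 0.4 = [[0.59, 0.41], [0.41, 0.59]]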
+def rand_bbox(img_shape, lam, margin=0., count=None):
+    """ Standard CutMix bounding-box
+    Generates a random square bbox based on lambda value. This impl includes
+    support for enforcing a border margin as percent of bbox dimensions.
+
+    Args:
+        img_shape (tuple): Image shape as tuple
+        lam (float): Cutmix lambda value
+        margin (float): Percentage of bbox dimension to enforce as margin (reduce amount of box outside image)
+        count (int): Number of bbox to generate
+    """
+    ratio = np.sqrt(1 - lam)
+    img_h, img_w = img_shape[-2:]
+    cut_h, cut_w = int(img_h * ratio), int(img_w * ratio)
+    margin_y, margin_x = int(margin * cut_h), int(margin * cut_w)
+    cy = np.random.randint(0 + margin_y, img_h - margin_y, size=count)
+    cx = np.random.randint(0 + margin_x, img_w - margin_x, size=count)
+    yl = np.clip(cy - cut_h // 2, 0, img_h)
+    yh = np.clip(cy + cut_h // 2, 0, img_h)
+    xl = np.clip(cx - cut_w // 2, 0, img_w)
+    xh = np.clip(cx + cut_w // 2, 0, img_w)
+    return yl, yh, xl, xh
+
+
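The box side scales with sqrt(1 - lam), so the nominal cut area tracks 1 - lam; a minimal sketch (shape and lam are illustrative):

    yl, yh, xl, xh = rand_bbox((3, 224, 224), lam=0.75)
    # ratio = sqrt(0.25) = 0.5 -> nominal cut is 112x112, i.e. 25% of the image
    # area. Centers are uniform over the image (margin=0.), so boxes near a
    # border get clipped smaller; cutmix_bbox_and_lam below can recompute lam
    # from that clipped area.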
+def rand_bbox_minmax(img_shape, minmax, count=None):
+    """ Min-Max CutMix bounding-box
+    Inspired by Darknet cutmix impl, generates a random rectangular bbox
+    based on min/max percent values applied to each dimension of the input image.
 
+    Typical defaults for minmax are usually in the .2-.3 range for min and .8-.9 range for max.
 
-def mixup_batch(input, target, alpha=0.2, num_classes=1000, smoothing=0.1, disable=False):
-    lam = 1.
-    if not disable:
-        lam = np.random.beta(alpha, alpha)
-    input = input.mul(lam).add_(1 - lam, input.flip(0))
-    target = mixup_target(target, num_classes, lam, smoothing)
-    return input, target
+    Args:
+        img_shape (tuple): Image shape as tuple
+        minmax (tuple or list): Min and max bbox ratios (as percent of image size)
+        count (int): Number of bbox to generate
+    """
+    assert len(minmax) == 2
+    img_h, img_w = img_shape[-2:]
+    cut_h = np.random.randint(int(img_h * minmax[0]), int(img_h * minmax[1]), size=count)
+    cut_w = np.random.randint(int(img_w * minmax[0]), int(img_w * minmax[1]), size=count)
+    yl = np.random.randint(0, img_h - cut_h, size=count)
+    xl = np.random.randint(0, img_w - cut_w, size=count)
+    yu = yl + cut_h
+    xu = xl + cut_w
+    return yl, yu, xl, xu
 
 
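Unlike rand_bbox, this variant ignores lam when drawing the box, so the mixing weight must be recovered from the box area afterwards (cutmix_bbox_and_lam does this whenever ratio_minmax is set); a small sketch with illustrative numbers:

    yl, yu, xl, xu = rand_bbox_minmax((3, 224, 224), (0.2, 0.8))
    # each side is drawn independently from [int(224 * .2), int(224 * .8)) = [44, 179)
    # pixels, giving a rectangular box that always lies fully inside the image.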
-class FastCollateMixup:
+def cutmix_bbox_and_lam(img_shape, lam, ratio_minmax=None, correct_lam=True, count=None):
+    """ Generate bbox and apply lambda correction.
+    """
+    if ratio_minmax is not None:
+        yl, yu, xl, xu = rand_bbox_minmax(img_shape, ratio_minmax, count=count)
+    else:
+        yl, yu, xl, xu = rand_bbox(img_shape, lam, count=count)
+    if correct_lam or ratio_minmax is not None:
+        bbox_area = (yu - yl) * (xu - xl)
+        lam = 1. - bbox_area / float(img_shape[-2] * img_shape[-1])
+    return (yl, yu, xl, xu), lam
 
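The correction replaces the nominal lam with one computed from the actual (possibly clipped) box area, so the target weights match the pixels really pasted; a worked example with hypothetical numbers:

    # a nominal 112x112 box (lam = 0.75) centered near a corner of a 224x224 image
    # might clip to 60x112 = 6720 px, giving the corrected value
    #   lam = 1. - 6720 / (224. * 224.)  # ~0.866, the fraction left untouched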
-    def __init__(self, mixup_alpha=1., label_smoothing=0.1, num_classes=1000):
+
+class Mixup:
+    """ Mixup/Cutmix that applies different params to each element or whole batch
+
+    Args:
+        mixup_alpha (float): mixup alpha value, mixup is active if > 0.
+        cutmix_alpha (float): cutmix alpha value, cutmix is active if > 0.
+        cutmix_minmax (List[float]): cutmix min/max image ratio, cutmix is active and uses this vs alpha if not None.
+        prob (float): probability of applying mixup or cutmix per batch or element
+        switch_prob (float): probability of switching to cutmix instead of mixup when both are active
+        elementwise (bool): apply mixup/cutmix params per batch element instead of per batch
+        correct_lam (bool): apply lambda correction when cutmix bbox clipped by image borders
+        label_smoothing (float): apply label smoothing to the mixed target tensor
+        num_classes (int): number of classes for target
+    """
+    def __init__(self, mixup_alpha=1., cutmix_alpha=0., cutmix_minmax=None, prob=1.0, switch_prob=0.5,
+                 elementwise=False, correct_lam=True, label_smoothing=0.1, num_classes=1000):
         self.mixup_alpha = mixup_alpha
+        self.cutmix_alpha = cutmix_alpha
+        self.cutmix_minmax = cutmix_minmax
+        if self.cutmix_minmax is not None:
+            assert len(self.cutmix_minmax) == 2
+            # force cutmix alpha == 1.0 when minmax active to keep logic simple & safe
+            self.cutmix_alpha = 1.0
+        self.mix_prob = prob
+        self.switch_prob = switch_prob
         self.label_smoothing = label_smoothing
         self.num_classes = num_classes
-        self.mixup_enabled = True
+        self.elementwise = elementwise
+        self.correct_lam = correct_lam  # correct lambda based on clipped area for cutmix
+        self.mixup_enabled = True  # set to false to disable mixing (intended to be set by train loop)
 
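A minimal sketch of constructing this class (argument values here are illustrative, not defaults from the diff):

    mixup_fn = Mixup(
        mixup_alpha=0.2, cutmix_alpha=1.0, switch_prob=0.5,  # both modes active, 50/50 switch
        prob=0.9, label_smoothing=0.1, num_classes=1000)
    mixup_fn.mixup_enabled = False  # e.g. train loop turns mixing off for the final epochs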
-    def __call__(self, batch):
-        batch_size = len(batch)
-        lam = 1.
+    def _params_per_elem(self, batch_size):
+        lam = np.ones(batch_size, dtype=np.float32)
+        use_cutmix = np.zeros(batch_size, dtype=np.bool)
         if self.mixup_enabled:
-            lam = np.random.beta(self.mixup_alpha, self.mixup_alpha)
+            if self.mixup_alpha > 0. and self.cutmix_alpha > 0.:
+                use_cutmix = np.random.rand(batch_size) < self.switch_prob
+                lam_mix = np.where(
+                    use_cutmix,
+                    np.random.beta(self.cutmix_alpha, self.cutmix_alpha, size=batch_size),
+                    np.random.beta(self.mixup_alpha, self.mixup_alpha, size=batch_size))
+            elif self.mixup_alpha > 0.:
+                lam_mix = np.random.beta(self.mixup_alpha, self.mixup_alpha, size=batch_size)
+            elif self.cutmix_alpha > 0.:
+                use_cutmix = np.ones(batch_size, dtype=np.bool)
+                lam_mix = np.random.beta(self.cutmix_alpha, self.cutmix_alpha, size=batch_size)
+            else:
+                assert False, "One of mixup_alpha > 0., cutmix_alpha > 0., cutmix_minmax not None should be true."
+            lam = np.where(np.random.rand(batch_size) < self.mix_prob, lam_mix.astype(np.float32), lam)
+        return lam, use_cutmix
 
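For intuition, each element independently decides whether it is mixed at all (mix_prob), which mode it uses (switch_prob), and its own lam; a toy illustration with hypothetical output (np.bool is the old numpy alias for bool, removed in numpy 1.24, so modern code would pass bool here):

    # lam, use_cutmix = mixup_fn._params_per_elem(4)  # mixup_fn as sketched above
    # lam        -> array([0.31, 1.  , 0.87, 0.52], dtype=float32)  # 1. = left unmixed
    # use_cutmix -> array([ True, False, False,  True])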
-        target = torch.tensor([b[1] for b in batch], dtype=torch.int64)
-        target = mixup_target(target, self.num_classes, lam, self.label_smoothing, device='cpu')
+    def _params_per_batch(self):
+        lam = 1.
+        use_cutmix = False
+        if self.mixup_enabled and np.random.rand() < self.mix_prob:
+            if self.mixup_alpha > 0. and self.cutmix_alpha > 0.:
+                use_cutmix = np.random.rand() < self.switch_prob
+                lam_mix = np.random.beta(self.cutmix_alpha, self.cutmix_alpha) if use_cutmix else \
+                    np.random.beta(self.mixup_alpha, self.mixup_alpha)
+            elif self.mixup_alpha > 0.:
+                lam_mix = np.random.beta(self.mixup_alpha, self.mixup_alpha)
+            elif self.cutmix_alpha > 0.:
+                use_cutmix = True
+                lam_mix = np.random.beta(self.cutmix_alpha, self.cutmix_alpha)
+            else:
+                assert False, "One of mixup_alpha > 0., cutmix_alpha > 0., cutmix_minmax not None should be true."
+            lam = float(lam_mix)
+        return lam, use_cutmix
 
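The per-batch draw implies simple expected rates; a quick worked example (illustrative numbers):

    # with prob=0.9, switch_prob=0.5 and both alphas > 0:
    #   P(batch unmixed)     = 0.1
    #   P(batch uses cutmix) = 0.9 * 0.5 = 0.45
    #   P(batch uses mixup)  = 0.9 * 0.5 = 0.45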
-        tensor = torch.zeros((batch_size, *batch[0][0].shape), dtype=torch.uint8)
+    def _mix_elem(self, x):
+        batch_size = len(x)
+        lam_batch, use_cutmix = self._params_per_elem(batch_size)
+        x_orig = x.clone()  # need to keep an unmodified original for mixing source
         for i in range(batch_size):
-            mixed = batch[i][0].astype(np.float32) * lam + \
-                batch[batch_size - i - 1][0].astype(np.float32) * (1 - lam)
-            np.round(mixed, out=mixed)
-            tensor[i] += torch.from_numpy(mixed.astype(np.uint8))
+            j = batch_size - i - 1
+            lam = lam_batch[i]
+            if lam != 1.:
+                if use_cutmix[i]:
+                    (yl, yh, xl, xh), lam = cutmix_bbox_and_lam(
+                        x[i].shape, lam, ratio_minmax=self.cutmix_minmax, correct_lam=self.correct_lam)
+                    x[i][:, yl:yh, xl:xh] = x_orig[j][:, yl:yh, xl:xh]
+                    lam_batch[i] = lam
+                else:
+                    x[i] = x[i] * lam + x_orig[j] * (1 - lam)
+        return torch.tensor(lam_batch, device=x.device, dtype=x.dtype).unsqueeze(1)
+
+    def _mix_batch(self, x):
+        lam, use_cutmix = self._params_per_batch()
+        if lam == 1.:
+            return 1.
+        if use_cutmix:
+            (yl, yh, xl, xh), lam = cutmix_bbox_and_lam(
+                x.shape, lam, ratio_minmax=self.cutmix_minmax, correct_lam=self.correct_lam)
+            x[:, :, yl:yh, xl:xh] = x.flip(0)[:, :, yl:yh, xl:xh]
+        else:
+            x_flipped = x.flip(0).mul_(1. - lam)
+            x.mul_(lam).add_(x_flipped)
+        return lam
+
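The non-cutmix branch is the usual convex combination, written with in-place ops to avoid an extra full-batch temporary; equivalently (a sketch, assuming a float tensor x of shape (B, C, H, W)):

    # x.flip(0) pairs sample i with sample B-1-i, so the two in-place ops compute
    #   x[i] = lam * x[i] + (1 - lam) * x[B-1-i]
    # for every i at once, while x.flip(0) is materialized only once.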
+    def __call__(self, x, target):
+        assert len(x) % 2 == 0, 'Batch size should be even when using this'
+        lam = self._mix_elem(x) if self.elementwise else self._mix_batch(x)
+        target = mixup_target(target, self.num_classes, lam, self.label_smoothing)
+        return x, target
+
+
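A minimal sketch of where __call__ sits in a training step; model and loader are placeholders, and the loss shown is a soft-target cross-entropy, needed because the returned targets are dense (B, num_classes) tensors:

    for inputs, targets in loader:
        inputs, targets = inputs.cuda(), targets.cuda()  # mixup_target defaults to device='cuda'
        inputs, targets = mixup_fn(inputs, targets)
        logits = model(inputs)
        loss = torch.sum(-targets * torch.nn.functional.log_softmax(logits, dim=-1), dim=-1).mean()
        loss.backward()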
+class FastCollateMixup(Mixup):
+    """ Fast Collate w/ Mixup/Cutmix that applies different params to each element or whole batch
+
+    A Mixup impl that's performed while collating the batches.
+    """
+
+    def _mix_elem_collate(self, output, batch):
+        batch_size = len(batch)
+        lam_batch, use_cutmix = self._params_per_elem(batch_size)
+        for i in range(batch_size):
+            j = batch_size - i - 1
+            lam = lam_batch[i]
+            mixed = batch[i][0]
+            if lam != 1.:
+                if use_cutmix[i]:
+                    mixed = mixed.copy()
+                    (yl, yh, xl, xh), lam = cutmix_bbox_and_lam(
+                        output.shape, lam, ratio_minmax=self.cutmix_minmax, correct_lam=self.correct_lam)
+                    mixed[:, yl:yh, xl:xh] = batch[j][0][:, yl:yh, xl:xh]
+                    lam_batch[i] = lam
+                else:
+                    mixed = mixed.astype(np.float32) * lam + batch[j][0].astype(np.float32) * (1 - lam)
+                    lam_batch[i] = lam
+                    np.round(mixed, out=mixed)
+            output[i] += torch.from_numpy(mixed.astype(np.uint8))
+        return torch.tensor(lam_batch).unsqueeze(1)
+
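Mixing here happens on the raw uint8 numpy images before stacking, so the mixup blend is rounded back to uint8 (the cutmix path just copies bytes); a sketch of the rounding step with illustrative pixel values:

    # mixed = img_a.astype(np.float32) * 0.6 + img_b.astype(np.float32) * 0.4
    # np.round(mixed, out=mixed)   # e.g. 0.6 * 200 + 0.4 * 50 = 140.0 -> 140
    # worst case this costs 0.5/255 of precision per pixel, negligible for training.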
+    def _mix_batch_collate(self, output, batch):
+        batch_size = len(batch)
+        lam, use_cutmix = self._params_per_batch()
+        if use_cutmix:
+            (yl, yh, xl, xh), lam = cutmix_bbox_and_lam(
+                output.shape, lam, ratio_minmax=self.cutmix_minmax, correct_lam=self.correct_lam)
+        for i in range(batch_size):
+            j = batch_size - i - 1
+            mixed = batch[i][0]
+            if lam != 1.:
+                if use_cutmix:
+                    mixed = mixed.copy()  # don't want to modify the original while iterating
+                    mixed[:, yl:yh, xl:xh] = batch[j][0][:, yl:yh, xl:xh]
+                else:
+                    mixed = mixed.astype(np.float32) * lam + batch[j][0].astype(np.float32) * (1 - lam)
+                    np.round(mixed, out=mixed)
+            output[i] += torch.from_numpy(mixed.astype(np.uint8))
+        return lam
+
+    def __call__(self, batch, _=None):
+        batch_size = len(batch)
+        assert batch_size % 2 == 0, 'Batch size should be even when using this'
+        output = torch.zeros((batch_size, *batch[0][0].shape), dtype=torch.uint8)
+        if self.elementwise:
+            lam = self._mix_elem_collate(output, batch)
+        else:
+            lam = self._mix_batch_collate(output, batch)
+        target = torch.tensor([b[1] for b in batch], dtype=torch.int64)
+        target = mixup_target(target, self.num_classes, lam, self.label_smoothing, device='cpu')
+        return output, target
 
-        return tensor, target
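A minimal sketch of the collate variant in use; dataset is a placeholder that is assumed to yield (uint8 CHW numpy image, int label) pairs, as the channel-first indexing above expects:

    collate_fn = FastCollateMixup(mixup_alpha=0.2, cutmix_alpha=1.0, num_classes=1000)
    loader = torch.utils.data.DataLoader(dataset, batch_size=128, shuffle=True, collate_fn=collate_fn)
    # each batch arrives as (uint8 image tensor, soft float target built on CPU);
    # conversion to normalized float is assumed to happen later, e.g. on the GPU.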