@@ -229,29 +229,41 @@ def _mix_elem_collate(self, output, batch, half=False):
229229 num_elem = batch_size // 2 if half else batch_size
230230 assert len (output ) == num_elem
231231 lam_batch , use_cutmix = self ._params_per_elem (num_elem )
232+ is_np = isinstance (batch [0 ][0 ], np .ndarray )
233+
232234 for i in range (num_elem ):
233235 j = batch_size - i - 1
234236 lam = lam_batch [i ]
235237 mixed = batch [i ][0 ]
236238 if lam != 1. :
237239 if use_cutmix [i ]:
238240 if not half :
239- mixed = mixed .copy ()
241+ mixed = mixed .copy () if is_np else mixed . clone ()
240242 (yl , yh , xl , xh ), lam = cutmix_bbox_and_lam (
241- output .shape , lam , ratio_minmax = self .cutmix_minmax , correct_lam = self .correct_lam )
243+ output .shape ,
244+ lam ,
245+ ratio_minmax = self .cutmix_minmax ,
246+ correct_lam = self .correct_lam ,
247+ )
242248 mixed [:, yl :yh , xl :xh ] = batch [j ][0 ][:, yl :yh , xl :xh ]
243249 lam_batch [i ] = lam
244250 else :
245- mixed = mixed .astype (np .float32 ) * lam + batch [j ][0 ].astype (np .float32 ) * (1 - lam )
246- np .rint (mixed , out = mixed )
247- output [i ] += torch .from_numpy (mixed .astype (np .uint8 ))
251+ if is_np :
252+ mixed = mixed .astype (np .float32 ) * lam + batch [j ][0 ].astype (np .float32 ) * (1 - lam )
253+ np .rint (mixed , out = mixed )
254+ else :
255+ mixed = mixed .float () * lam + batch [j ][0 ].float () * (1 - lam )
256+ torch .round (mixed , out = mixed )
257+ output [i ] += torch .from_numpy (mixed .astype (np .uint8 )) if is_np else mixed .byte ()
248258 if half :
249259 lam_batch = np .concatenate ((lam_batch , np .ones (num_elem )))
250260 return torch .tensor (lam_batch ).unsqueeze (1 )
251261
252262 def _mix_pair_collate (self , output , batch ):
253263 batch_size = len (batch )
254264 lam_batch , use_cutmix = self ._params_per_elem (batch_size // 2 )
265+ is_np = isinstance (batch [0 ][0 ], np .ndarray )
266+
255267 for i in range (batch_size // 2 ):
256268 j = batch_size - i - 1
257269 lam = lam_batch [i ]
@@ -261,39 +273,60 @@ def _mix_pair_collate(self, output, batch):
261273 if lam < 1. :
262274 if use_cutmix [i ]:
263275 (yl , yh , xl , xh ), lam = cutmix_bbox_and_lam (
264- output .shape , lam , ratio_minmax = self .cutmix_minmax , correct_lam = self .correct_lam )
265- patch_i = mixed_i [:, yl :yh , xl :xh ].copy ()
276+ output .shape ,
277+ lam ,
278+ ratio_minmax = self .cutmix_minmax ,
279+ correct_lam = self .correct_lam ,
280+ )
281+ patch_i = mixed_i [:, yl :yh , xl :xh ].copy () if is_np else mixed_i [:, yl :yh , xl :xh ].clone ()
266282 mixed_i [:, yl :yh , xl :xh ] = mixed_j [:, yl :yh , xl :xh ]
267283 mixed_j [:, yl :yh , xl :xh ] = patch_i
268284 lam_batch [i ] = lam
269285 else :
270- mixed_temp = mixed_i .astype (np .float32 ) * lam + mixed_j .astype (np .float32 ) * (1 - lam )
271- mixed_j = mixed_j .astype (np .float32 ) * lam + mixed_i .astype (np .float32 ) * (1 - lam )
272- mixed_i = mixed_temp
273- np .rint (mixed_j , out = mixed_j )
274- np .rint (mixed_i , out = mixed_i )
275- output [i ] += torch .from_numpy (mixed_i .astype (np .uint8 ))
276- output [j ] += torch .from_numpy (mixed_j .astype (np .uint8 ))
286+ if is_np :
287+ mixed_temp = mixed_i .astype (np .float32 ) * lam + mixed_j .astype (np .float32 ) * (1 - lam )
288+ mixed_j = mixed_j .astype (np .float32 ) * lam + mixed_i .astype (np .float32 ) * (1 - lam )
289+ mixed_i = mixed_temp
290+ np .rint (mixed_j , out = mixed_j )
291+ np .rint (mixed_i , out = mixed_i )
292+ else :
293+ mixed_temp = mixed_i .float () * lam + mixed_j .float () * (1 - lam )
294+ mixed_j = mixed_j .float () * lam + mixed_i .float () * (1 - lam )
295+ mixed_i = mixed_temp
296+ torch .round (mixed_j , out = mixed_j )
297+ torch .round (mixed_i , out = mixed_i )
298+ output [i ] += torch .from_numpy (mixed_i .astype (np .uint8 )) if is_np else mixed_i .byte ()
299+ output [j ] += torch .from_numpy (mixed_j .astype (np .uint8 )) if is_np else mixed_j .byte ()
277300 lam_batch = np .concatenate ((lam_batch , lam_batch [::- 1 ]))
278301 return torch .tensor (lam_batch ).unsqueeze (1 )
279302
280303 def _mix_batch_collate (self , output , batch ):
281304 batch_size = len (batch )
282305 lam , use_cutmix = self ._params_per_batch ()
306+ is_np = isinstance (batch [0 ][0 ], np .ndarray )
307+
283308 if use_cutmix :
284309 (yl , yh , xl , xh ), lam = cutmix_bbox_and_lam (
285- output .shape , lam , ratio_minmax = self .cutmix_minmax , correct_lam = self .correct_lam )
310+ output .shape ,
311+ lam ,
312+ ratio_minmax = self .cutmix_minmax ,
313+ correct_lam = self .correct_lam ,
314+ )
286315 for i in range (batch_size ):
287316 j = batch_size - i - 1
288317 mixed = batch [i ][0 ]
289318 if lam != 1. :
290319 if use_cutmix :
291- mixed = mixed .copy () # don't want to modify the original while iterating
320+ mixed = mixed .copy () if is_np else mixed . clone () # don't want to modify the original while iterating
292321 mixed [:, yl :yh , xl :xh ] = batch [j ][0 ][:, yl :yh , xl :xh ]
293322 else :
294- mixed = mixed .astype (np .float32 ) * lam + batch [j ][0 ].astype (np .float32 ) * (1 - lam )
295- np .rint (mixed , out = mixed )
296- output [i ] += torch .from_numpy (mixed .astype (np .uint8 ))
323+ if is_np :
324+ mixed = mixed .astype (np .float32 ) * lam + batch [j ][0 ].astype (np .float32 ) * (1 - lam )
325+ np .rint (mixed , out = mixed )
326+ else :
327+ mixed = mixed .float () * lam + batch [j ][0 ].float () * (1 - lam )
328+ torch .round (mixed , out = mixed )
329+ output [i ] += torch .from_numpy (mixed .astype (np .uint8 )) if is_np else mixed .byte ()
297330 return lam
298331
299332 def __call__ (self , batch , _ = None ):
0 commit comments