Skip to content

Commit 99a09eb

Browse files
committed
Update old FastCollateMixup to accept torch tensor inputs instead of only numpy arrays
1 parent 4ff865c commit 99a09eb

File tree

1 file changed

+52
-19
lines changed

1 file changed

+52
-19
lines changed

timm/data/mixup.py

Lines changed: 52 additions & 19 deletions
Original file line numberDiff line numberDiff line change
def _mix_elem_collate(self, output, batch, half=False):
    """Mix each sample with its mirror-index partner using a per-element lambda.

    Accumulates the mixed images into `output` in place. Accepts batch items
    whose image is either a `np.ndarray` or a `torch.Tensor` (uint8, CHW) —
    the tensor path is what this commit adds.

    Args:
        output: pre-allocated uint8 tensor of shape (num_elem, C, H, W),
            accumulated into via `+=` (assumed zero-initialized by the caller
            — TODO confirm against the collate entry point).
        batch: list of (image, target) pairs; image is a uint8 ndarray or tensor.
        half: if True, only the first half of the batch is mixed; the second
            half's lambdas are reported as 1.0.

    Returns:
        Column tensor of shape (batch_size, 1) holding per-sample mix lambdas.
    """
    batch_size = len(batch)
    num_elem = batch_size // 2 if half else batch_size
    assert len(output) == num_elem
    lam_batch, use_cutmix = self._params_per_elem(num_elem)
    # Inputs may be numpy arrays or torch tensors; decide once up front.
    is_np = isinstance(batch[0][0], np.ndarray)
    for i in range(num_elem):
        j = batch_size - i - 1  # mix with the mirrored element
        lam = lam_batch[i]
        mixed = batch[i][0]
        if lam != 1.:
            if use_cutmix[i]:
                if not half:
                    # Copy/clone before pasting the patch so sample i (still
                    # needed as a mix source for element j) isn't clobbered.
                    mixed = mixed.copy() if is_np else mixed.clone()
                (yl, yh, xl, xh), lam = cutmix_bbox_and_lam(
                    output.shape,
                    lam,
                    ratio_minmax=self.cutmix_minmax,
                    correct_lam=self.correct_lam,
                )
                mixed[:, yl:yh, xl:xh] = batch[j][0][:, yl:yh, xl:xh]
                lam_batch[i] = lam  # lam may be corrected to the true area ratio
            else:
                # Classic mixup: blend in float32, round back for uint8 output.
                if is_np:
                    mixed = mixed.astype(np.float32) * lam + batch[j][0].astype(np.float32) * (1 - lam)
                    np.rint(mixed, out=mixed)
                else:
                    mixed = mixed.float() * lam + batch[j][0].float() * (1 - lam)
                    torch.round(mixed, out=mixed)
        output[i] += torch.from_numpy(mixed.astype(np.uint8)) if is_np else mixed.byte()
    if half:
        # Unmixed second half: report lambda 1.0 for those samples.
        lam_batch = np.concatenate((lam_batch, np.ones(num_elem)))
    return torch.tensor(lam_batch).unsqueeze(1)
251261

252262
def _mix_pair_collate(self, output, batch):
    """Mix mirrored pairs symmetrically: element i gets lam, element j gets 1 - lam.

    For cutmix the patch is swapped between the pair; for mixup the two blends
    are computed from the unmodified originals before rounding. Handles both
    `np.ndarray` and `torch.Tensor` uint8 CHW images.

    Args:
        output: pre-allocated uint8 tensor (batch_size, C, H, W), accumulated
            into via `+=`.
        batch: list of (image, target) pairs.

    Returns:
        Column tensor (batch_size, 1) of per-sample lambdas; the second half
        is the first half reversed, mirroring the pairing.
    """
    batch_size = len(batch)
    lam_batch, use_cutmix = self._params_per_elem(batch_size // 2)
    is_np = isinstance(batch[0][0], np.ndarray)
    for i in range(batch_size // 2):
        j = batch_size - i - 1
        lam = lam_batch[i]
        # NOTE(review): the two assignments below fall in a gap between diff
        # hunks in the rendered source and were reconstructed from later
        # usage — verify against the upstream file.
        mixed_i = batch[i][0]
        mixed_j = batch[j][0]
        if lam < 1.:
            if use_cutmix[i]:
                (yl, yh, xl, xh), lam = cutmix_bbox_and_lam(
                    output.shape,
                    lam,
                    ratio_minmax=self.cutmix_minmax,
                    correct_lam=self.correct_lam,
                )
                # Swap the cutmix patch between the two images; the copy/clone
                # keeps i's patch alive while it is overwritten.
                patch_i = mixed_i[:, yl:yh, xl:xh].copy() if is_np else mixed_i[:, yl:yh, xl:xh].clone()
                mixed_i[:, yl:yh, xl:xh] = mixed_j[:, yl:yh, xl:xh]
                mixed_j[:, yl:yh, xl:xh] = patch_i
                lam_batch[i] = lam  # lam may be corrected to the true area ratio
            else:
                # Compute both blends from the originals (via mixed_temp)
                # before rebinding, then round in place.
                if is_np:
                    mixed_temp = mixed_i.astype(np.float32) * lam + mixed_j.astype(np.float32) * (1 - lam)
                    mixed_j = mixed_j.astype(np.float32) * lam + mixed_i.astype(np.float32) * (1 - lam)
                    mixed_i = mixed_temp
                    np.rint(mixed_j, out=mixed_j)
                    np.rint(mixed_i, out=mixed_i)
                else:
                    mixed_temp = mixed_i.float() * lam + mixed_j.float() * (1 - lam)
                    mixed_j = mixed_j.float() * lam + mixed_i.float() * (1 - lam)
                    mixed_i = mixed_temp
                    torch.round(mixed_j, out=mixed_j)
                    torch.round(mixed_i, out=mixed_i)
        output[i] += torch.from_numpy(mixed_i.astype(np.uint8)) if is_np else mixed_i.byte()
        output[j] += torch.from_numpy(mixed_j.astype(np.uint8)) if is_np else mixed_j.byte()
    lam_batch = np.concatenate((lam_batch, lam_batch[::-1]))
    return torch.tensor(lam_batch).unsqueeze(1)
279302

280303
def _mix_batch_collate(self, output, batch):
    """Mix the whole batch with a single shared lambda (and one cutmix bbox).

    Each element i is mixed with its mirror j = batch_size - i - 1 using the
    same `lam` drawn once per batch. Handles both `np.ndarray` and
    `torch.Tensor` uint8 CHW images.

    Args:
        output: pre-allocated uint8 tensor (batch_size, C, H, W), accumulated
            into via `+=`.
        batch: list of (image, target) pairs.

    Returns:
        The (possibly area-corrected) scalar lambda used for the batch.
    """
    batch_size = len(batch)
    lam, use_cutmix = self._params_per_batch()
    is_np = isinstance(batch[0][0], np.ndarray)
    if use_cutmix:
        # One bbox shared by every pair in the batch.
        (yl, yh, xl, xh), lam = cutmix_bbox_and_lam(
            output.shape,
            lam,
            ratio_minmax=self.cutmix_minmax,
            correct_lam=self.correct_lam,
        )
    for i in range(batch_size):
        j = batch_size - i - 1
        mixed = batch[i][0]
        if lam != 1.:
            if use_cutmix:
                # don't want to modify the original while iterating
                mixed = mixed.copy() if is_np else mixed.clone()
                mixed[:, yl:yh, xl:xh] = batch[j][0][:, yl:yh, xl:xh]
            else:
                # Blend in float32, round back for uint8 output.
                if is_np:
                    mixed = mixed.astype(np.float32) * lam + batch[j][0].astype(np.float32) * (1 - lam)
                    np.rint(mixed, out=mixed)
                else:
                    mixed = mixed.float() * lam + batch[j][0].float() * (1 - lam)
                    torch.round(mixed, out=mixed)
        output[i] += torch.from_numpy(mixed.astype(np.uint8)) if is_np else mixed.byte()
    return lam
298331

299332
def __call__(self, batch, _=None):

0 commit comments

Comments
 (0)