1919def transforms_noaug_train (
2020 img_size : Union [int , Tuple [int , int ]] = 224 ,
2121 interpolation : str = 'bilinear' ,
22- use_prefetcher : bool = False ,
2322 mean : Tuple [float , ...] = IMAGENET_DEFAULT_MEAN ,
2423 std : Tuple [float , ...] = IMAGENET_DEFAULT_STD ,
24+ use_prefetcher : bool = False ,
25+ normalize : bool = True ,
2526):
2627 """ No-augmentation image transforms for training.
2728
@@ -31,6 +32,7 @@ def transforms_noaug_train(
3132 mean: Image normalization mean.
3233 std: Image normalization standard deviation.
3334 use_prefetcher: Prefetcher enabled. Do not convert image to tensor or normalize.
35+ normalize: Normalization tensor output w/ provided mean/std (if prefetcher not used).
3436
3537 Returns:
3638
@@ -45,6 +47,9 @@ def transforms_noaug_train(
4547 if use_prefetcher :
4648 # prefetcher and collate will handle tensor conversion and norm
4749 tfl += [ToNumpy ()]
50+ elif not normalize :
51+ # when normalize disabled, converted to tensor without scaling, keep original dtype
52+ tfl += [transforms .PILToTensor ()]
4853 else :
4954 tfl += [
5055 transforms .ToTensor (),
@@ -77,6 +82,7 @@ def transforms_imagenet_train(
7782 re_count : int = 1 ,
7883 re_num_splits : int = 0 ,
7984 use_prefetcher : bool = False ,
85+ normalize : bool = True ,
8086 separate : bool = False ,
8187):
8288 """ ImageNet-oriented image transforms for training.
@@ -103,6 +109,7 @@ def transforms_imagenet_train(
103109 re_count: Number of random erasing regions.
104110 re_num_splits: Control split of random erasing across batch size.
105111 use_prefetcher: Prefetcher enabled. Do not convert image to tensor or normalize.
112+ normalize: Normalize tensor output w/ provided mean/std (if prefetcher not used).
106113 separate: Output transforms in 3-stage tuple.
107114
108115 Returns:
@@ -209,12 +216,15 @@ def transforms_imagenet_train(
209216 if use_prefetcher :
210217 # prefetcher and collate will handle tensor conversion and norm
211218 final_tfl += [ToNumpy ()]
219+ elif not normalize :
220+ # when normalize disable, converted to tensor without scaling, keeps original dtype
221+ final_tfl += [transforms .PILToTensor ()]
212222 else :
213223 final_tfl += [
214224 transforms .ToTensor (),
215225 transforms .Normalize (
216226 mean = torch .tensor (mean ),
217- std = torch .tensor (std )
227+ std = torch .tensor (std ),
218228 ),
219229 ]
220230 if re_prob > 0. :
@@ -243,6 +253,7 @@ def transforms_imagenet_eval(
243253 mean : Tuple [float , ...] = IMAGENET_DEFAULT_MEAN ,
244254 std : Tuple [float , ...] = IMAGENET_DEFAULT_STD ,
245255 use_prefetcher : bool = False ,
256+ normalize : bool = True ,
246257):
247258 """ ImageNet-oriented image transform for evaluation and inference.
248259
@@ -255,6 +266,7 @@ def transforms_imagenet_eval(
255266 mean: Image normalization mean.
256267 std: Image normalization standard deviation.
257268 use_prefetcher: Prefetcher enabled. Do not convert image to tensor or normalize.
269+ normalize: Normalize tensor output w/ provided mean/std (if prefetcher not used).
258270
259271 Returns:
260272 Composed transform pipeline
@@ -304,13 +316,16 @@ def transforms_imagenet_eval(
304316 if use_prefetcher :
305317 # prefetcher and collate will handle tensor conversion and norm
306318 tfl += [ToNumpy ()]
319+ elif not normalize :
320+ # when normalize disabled, converted to tensor without scaling, keeps original dtype
321+ tfl += [transforms .PILToTensor ()]
307322 else :
308323 tfl += [
309324 transforms .ToTensor (),
310325 transforms .Normalize (
311326 mean = torch .tensor (mean ),
312327 std = torch .tensor (std ),
313- )
328+ ),
314329 ]
315330
316331 return transforms .Compose (tfl )
@@ -342,6 +357,7 @@ def create_transform(
342357 crop_border_pixels : Optional [int ] = None ,
343358 tf_preprocessing : bool = False ,
344359 use_prefetcher : bool = False ,
360+ normalize : bool = True ,
345361 separate : bool = False ,
346362):
347363 """
@@ -373,6 +389,7 @@ def create_transform(
373389 crop_border_pixels: Inference crop border of specified # pixels around edge of original image.
374390 tf_preprocessing: Use TF 1.0 inference preprocessing for testing model ports
375391 use_prefetcher: Pre-fetcher enabled. Do not convert image to tensor or normalize.
392+ normalize: Normalization tensor output w/ provided mean/std (if prefetcher not used).
376393 separate: Output transforms in 3-stage tuple.
377394
378395 Returns:
@@ -397,9 +414,10 @@ def create_transform(
397414 transform = transforms_noaug_train (
398415 img_size ,
399416 interpolation = interpolation ,
400- use_prefetcher = use_prefetcher ,
401417 mean = mean ,
402418 std = std ,
419+ use_prefetcher = use_prefetcher ,
420+ normalize = normalize ,
403421 )
404422 elif is_training :
405423 transform = transforms_imagenet_train (
@@ -415,26 +433,28 @@ def create_transform(
415433 gaussian_blur_prob = gaussian_blur_prob ,
416434 auto_augment = auto_augment ,
417435 interpolation = interpolation ,
418- use_prefetcher = use_prefetcher ,
419436 mean = mean ,
420437 std = std ,
421438 re_prob = re_prob ,
422439 re_mode = re_mode ,
423440 re_count = re_count ,
424441 re_num_splits = re_num_splits ,
442+ use_prefetcher = use_prefetcher ,
443+ normalize = normalize ,
425444 separate = separate ,
426445 )
427446 else :
428447 assert not separate , "Separate transforms not supported for validation preprocessing"
429448 transform = transforms_imagenet_eval (
430449 img_size ,
431450 interpolation = interpolation ,
432- use_prefetcher = use_prefetcher ,
433451 mean = mean ,
434452 std = std ,
435453 crop_pct = crop_pct ,
436454 crop_mode = crop_mode ,
437455 crop_border_pixels = crop_border_pixels ,
456+ use_prefetcher = use_prefetcher ,
457+ normalize = normalize ,
438458 )
439459
440460 return transform
0 commit comments