Skip to content

Commit a684cf0

Browse files
AdamRajfernv-kkudrynski
authored andcommitted
[ConvNets/PyT] Fix interpolation type from Image.* to InterpolationMode.*
1 parent f613b7c commit a684cf0

File tree

1 file changed

+9
-6
lines changed

1 file changed

+9
-6
lines changed

PyTorch/Classification/ConvNets/image_classification/dataloaders.py

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
import torchvision.transforms as transforms
3535
from PIL import Image
3636
from functools import partial
37+
from torchvision.transforms.functional import InterpolationMode
3738

3839
from image_classification.autoaugment import AutoaugmentImageNetPolicy
3940

@@ -422,9 +423,10 @@ def get_pytorch_train_loader(
422423
prefetch_factor=2,
423424
memory_format=torch.contiguous_format,
424425
):
425-
interpolation = {"bicubic": Image.BICUBIC, "bilinear": Image.BILINEAR}[
426-
interpolation
427-
]
426+
interpolation = {
427+
"bicubic": InterpolationMode.BICUBIC,
428+
"bilinear": InterpolationMode.BILINEAR,
429+
}[interpolation]
428430
traindir = os.path.join(data_path, "train")
429431
transforms_list = [
430432
transforms.RandomResizedCrop(image_size, interpolation=interpolation),
@@ -474,9 +476,10 @@ def get_pytorch_val_loader(
474476
memory_format=torch.contiguous_format,
475477
prefetch_factor=2,
476478
):
477-
interpolation = {"bicubic": Image.BICUBIC, "bilinear": Image.BILINEAR}[
478-
interpolation
479-
]
479+
interpolation = {
480+
"bicubic": InterpolationMode.BICUBIC,
481+
"bilinear": InterpolationMode.BILINEAR,
482+
}[interpolation]
480483
valdir = os.path.join(data_path, "val")
481484
val_dataset = datasets.ImageFolder(
482485
valdir,

0 commit comments

Comments
 (0)