diff --git a/torch2trt/converters/interpolate.py b/torch2trt/converters/interpolate.py index bee810d0..22f8bcb9 100644 --- a/torch2trt/converters/interpolate.py +++ b/torch2trt/converters/interpolate.py @@ -2,7 +2,10 @@ import torch.nn as nn from torch2trt.torch2trt import * from torch2trt.module_test import add_module_test -import collections +try: + from collections import Sequence +except ImportError: + from collections.abc import Sequence def has_interpolate_plugin(): @@ -66,7 +69,7 @@ def convert_interpolate_trt7(ctx): shape = size if shape != None: - if isinstance(shape, collections.Sequence): + if isinstance(shape, Sequence): shape = [input.size(0), input.size(1)] + list(shape) else: shape = [input.size(0), input.size(1)] + [shape] * input_dim @@ -75,7 +78,7 @@ def convert_interpolate_trt7(ctx): scales = scale_factor if scales != None: - if not isinstance(scales, collections.Sequence): + if not isinstance(scales, Sequence): scales = [scales] * input_dim layer.scales = [1, 1] + list(scales)