From 3aea5c0d77f15bf043b6bdd3f7cc699d99acceec Mon Sep 17 00:00:00 2001 From: Mahdi Lamb Date: Wed, 31 Aug 2022 18:07:17 +0100 Subject: [PATCH 1/2] Update interpolate.py Modify to work in python where Sequence is in collections.abc --- torch2trt/converters/interpolate.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/torch2trt/converters/interpolate.py b/torch2trt/converters/interpolate.py index bee810d0..38190f9f 100644 --- a/torch2trt/converters/interpolate.py +++ b/torch2trt/converters/interpolate.py @@ -2,7 +2,10 @@ import torch.nn as nn from torch2trt.torch2trt import * from torch2trt.module_test import add_module_test -import collections +try: + from collections import Sequence +except ImportError: + from collections.abc import Sequence def has_interpolate_plugin(): @@ -66,7 +69,7 @@ def convert_interpolate_trt7(ctx): shape = size if shape != None: - if isinstance(shape, collections.Sequence): + if isinstance(shape, Sequence): shape = [input.size(0), input.size(1)] + list(shape) else: shape = [input.size(0), input.size(1)] + [shape] * input_dim @@ -75,7 +78,7 @@ def convert_interpolate_trt7(ctx): scales = scale_factor if scales != None: - if not isinstance(scales, collections.Sequence): + if not isinstance(scales, Sequence): scales = [scales] * input_dim layer.scales = [1, 1] + list(scales) From c868d6c67b5741bc76b2314c9584489d547cd0b6 Mon Sep 17 00:00:00 2001 From: Mahdi Lamb Date: Wed, 31 Aug 2022 18:12:06 +0100 Subject: [PATCH 2/2] Update interpolate.py --- torch2trt/converters/interpolate.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/torch2trt/converters/interpolate.py b/torch2trt/converters/interpolate.py index 38190f9f..22f8bcb9 100644 --- a/torch2trt/converters/interpolate.py +++ b/torch2trt/converters/interpolate.py @@ -3,9 +3,9 @@ from torch2trt.torch2trt import * from torch2trt.module_test import add_module_test try: - from collections import Sequence + from collections import Sequence except ImportError: - from collections.abc import Sequence + from collections.abc import Sequence def has_interpolate_plugin():