11import copy
2+ import importlib .util
23import unittest
34from typing import Dict
45
56import torch
67import torch_tensorrt as torchtrt
7- import torchvision .models as models
88from utils import same_output_format
99
10+ if importlib .util .find_spec ("torchvision" ):
11+ import torchvision .models as models
12+
1013
1114@unittest .skipIf (
1215 torchtrt .ENABLED_FEATURES .tensorrt_rtx ,
1316 "aten::adaptive_avg_pool2d is implemented via plugins which is not supported for tensorrt_rtx" ,
1417)
18+ @unittest .skipIf (
19+ not importlib .util .find_spec ("torchvision" ), "torchvision not installed"
20+ )
1521class TestInputTypeDefaultsFP32Model (unittest .TestCase ):
22+
1623 def test_input_use_default_fp32 (self ):
1724 self .model = models .resnet18 (pretrained = True ).eval ().to ("cuda" )
1825 self .input = torch .randn ((1 , 3 , 224 , 224 )).to ("cuda" )
@@ -60,6 +67,9 @@ class TestInputTypeDefaultsFP16Model(unittest.TestCase):
6067 torchtrt .ENABLED_FEATURES .tensorrt_rtx ,
6168 "aten::adaptive_avg_pool2d is implemented via plugins which is not supported for tensorrt_rtx" ,
6269 )
70+ @unittest .skipIf (
71+ not importlib .util .find_spec ("torchvision" ), "torchvision not installed"
72+ )
6373 def test_input_use_default_fp16 (self ):
6474 self .model = models .resnet18 (pretrained = True ).eval ().to ("cuda" )
6575 self .input = torch .randn ((1 , 3 , 224 , 224 )).to ("cuda" )
@@ -78,6 +88,9 @@ def test_input_use_default_fp16(self):
7888 torchtrt .ENABLED_FEATURES .tensorrt_rtx ,
7989 "aten::adaptive_avg_pool2d is implemented via plugins which is not supported for tensorrt_rtx" ,
8090 )
91+ @unittest .skipIf (
92+ not importlib .util .find_spec ("torchvision" ), "torchvision not installed"
93+ )
8194 def test_input_use_default_fp16_without_fp16_enabled (self ):
8295 self .model = models .resnet18 (pretrained = True ).eval ().to ("cuda" )
8396 self .input = torch .randn ((1 , 3 , 224 , 224 )).to ("cuda" )
@@ -94,6 +107,9 @@ def test_input_use_default_fp16_without_fp16_enabled(self):
94107 torchtrt .ENABLED_FEATURES .tensorrt_rtx ,
95108 "aten::adaptive_avg_pool2d is implemented via plugins which is not supported for tensorrt_rtx" ,
96109 )
110+ @unittest .skipIf (
111+ not importlib .util .find_spec ("torchvision" ), "torchvision not installed"
112+ )
97113 def test_input_respect_user_setting_fp16_weights_fp32_in (self ):
98114 self .model = models .resnet18 (pretrained = True ).eval ().to ("cuda" )
99115 self .input = torch .randn ((1 , 3 , 224 , 224 )).to ("cuda" )
@@ -113,6 +129,9 @@ def test_input_respect_user_setting_fp16_weights_fp32_in(self):
113129 torchtrt .ENABLED_FEATURES .tensorrt_rtx ,
114130 "aten::adaptive_avg_pool2d is implemented via plugins which is not supported for tensorrt_rtx" ,
115131 )
132+ @unittest .skipIf (
133+ not importlib .util .find_spec ("torchvision" ), "torchvision not installed"
134+ )
116135 def test_input_respect_user_setting_fp16_weights_fp32_in_non_constuctor (self ):
117136 self .model = models .resnet18 (pretrained = True ).eval ().to ("cuda" )
118137 self .input = torch .randn ((1 , 3 , 224 , 224 )).to ("cuda" )
0 commit comments