1+ import unittest
2+ import torch_tensorrt as torchtrt
3+ import torch
4+ import torchvision .models as models
5+ import copy
6+ from typing import Dict
7+
8+ class TestDevice (unittest .TestCase ):
9+
10+ def test_from_string_constructor (self ):
11+ device = torchtrt .Device ("cuda:0" )
12+ self .assertEqual (device .device_type , torchtrt .DeviceType .GPU )
13+ self .assertEqual (device .gpu_id , 0 )
14+
15+ device = torchtrt .Device ("gpu:1" )
16+ self .assertEqual (device .device_type , torchtrt .DeviceType .GPU )
17+ self .assertEqual (device .gpu_id , 1 )
18+
19+ def test_from_string_constructor_dla (self ):
20+ device = torchtrt .Device ("dla:0" )
21+ self .assertEqual (device .device_type , torchtrt .DeviceType .DLA )
22+ self .assertEqual (device .gpu_id , 0 )
23+ self .assertEqual (device .dla_core , 0 )
24+
25+ device = torchtrt .Device ("dla:1" , allow_gpu_fallback = True )
26+ self .assertEqual (device .device_type , torchtrt .DeviceType .DLA )
27+ self .assertEqual (device .gpu_id , 0 )
28+ self .assertEqual (device .dla_core , 1 )
29+ self .assertEqual (device .allow_gpu_fallback , True )
30+
31+ def test_kwargs_gpu (self ):
32+ device = torchtrt .Device (gpu_id = 0 )
33+ self .assertEqual (device .device_type , torchtrt .DeviceType .GPU )
34+ self .assertEqual (device .gpu_id , 0 )
35+
36+ def test_kwargs_dla_and_settings (self ):
37+ device = torchtrt .Device (dla_core = 1 , allow_gpu_fallback = False )
38+ self .assertEqual (device .device_type , torchtrt .DeviceType .DLA )
39+ self .assertEqual (device .gpu_id , 0 )
40+ self .assertEqual (device .dla_core , 1 )
41+ self .assertEqual (device .allow_gpu_fallback , False )
42+
43+ device = torchtrt .Device (gpu_id = 1 , dla_core = 0 , allow_gpu_fallback = True )
44+ self .assertEqual (device .device_type , torchtrt .DeviceType .DLA )
45+ self .assertEqual (device .gpu_id , 1 )
46+ self .assertEqual (device .dla_core , 0 )
47+ self .assertEqual (device .allow_gpu_fallback , True )
48+
49+ def test_from_torch (self ):
50+ device = torchtrt .Device ._from_torch_device (torch .device ("cuda:0" ))
51+ self .assertEqual (device .device_type , torchtrt .DeviceType .GPU )
52+ self .assertEqual (device .gpu_id , 0 )
53+
54+
55+ class TestInput (unittest .TestCase ):
56+
57+ def _verify_correctness (self , struct : torchtrt .Input , target : Dict ) -> bool :
58+ internal = struct ._to_internal ()
59+
60+ list_eq = lambda al , bl : all ([a == b for (a , b ) in zip (al , bl )])
61+
62+ eq = lambda a , b : a == b
63+
64+ def field_is_correct (field , equal_fn , a1 , a2 ):
65+ equal = equal_fn (a1 , a2 )
66+ if not equal :
67+ print ("\n Field {} is incorrect: {} != {}" .format (field , a1 , a2 ))
68+ return equal
69+
70+ min_ = field_is_correct ("min" , list_eq , internal .min , target ["min" ])
71+ opt_ = field_is_correct ("opt" , list_eq , internal .opt , target ["opt" ])
72+ max_ = field_is_correct ("max" , list_eq , internal .max , target ["max" ])
73+ is_dynamic_ = field_is_correct ("is_dynamic" , eq , internal .input_is_dynamic , target ["input_is_dynamic" ])
74+ explicit_set_dtype_ = field_is_correct ("explicit_dtype" , eq , internal ._explicit_set_dtype ,
75+ target ["explicit_set_dtype" ])
76+ dtype_ = field_is_correct ("dtype" , eq , int (internal .dtype ), int (target ["dtype" ]))
77+ format_ = field_is_correct ("format" , eq , int (internal .format ), int (target ["format" ]))
78+
79+ return all ([min_ , opt_ , max_ , is_dynamic_ , explicit_set_dtype_ , dtype_ , format_ ])
80+
81+ def test_infer_from_example_tensor (self ):
82+ shape = [1 , 3 , 255 , 255 ]
83+ target = {
84+ "min" : shape ,
85+ "opt" : shape ,
86+ "max" : shape ,
87+ "input_is_dynamic" : False ,
88+ "dtype" : torchtrt .dtype .half ,
89+ "format" : torchtrt .TensorFormat .contiguous ,
90+ "explicit_set_dtype" : True
91+ }
92+
93+ example_tensor = torch .randn (shape ).half ()
94+ i = torchtrt .Input ._from_tensor (example_tensor )
95+ self .assertTrue (self ._verify_correctness (i , target ))
96+
97+ def test_static_shape (self ):
98+ shape = [1 , 3 , 255 , 255 ]
99+ target = {
100+ "min" : shape ,
101+ "opt" : shape ,
102+ "max" : shape ,
103+ "input_is_dynamic" : False ,
104+ "dtype" : torchtrt .dtype .unknown ,
105+ "format" : torchtrt .TensorFormat .contiguous ,
106+ "explicit_set_dtype" : False
107+ }
108+
109+ i = torchtrt .Input (shape )
110+ self .assertTrue (self ._verify_correctness (i , target ))
111+
112+ i = torchtrt .Input (tuple (shape ))
113+ self .assertTrue (self ._verify_correctness (i , target ))
114+
115+ i = torchtrt .Input (torch .randn (shape ).shape )
116+ self .assertTrue (self ._verify_correctness (i , target ))
117+
118+ i = torchtrt .Input (shape = shape )
119+ self .assertTrue (self ._verify_correctness (i , target ))
120+
121+ i = torchtrt .Input (shape = tuple (shape ))
122+ self .assertTrue (self ._verify_correctness (i , target ))
123+
124+ i = torchtrt .Input (shape = torch .randn (shape ).shape )
125+ self .assertTrue (self ._verify_correctness (i , target ))
126+
127+ def test_data_type (self ):
128+ shape = [1 , 3 , 255 , 255 ]
129+ target = {
130+ "min" : shape ,
131+ "opt" : shape ,
132+ "max" : shape ,
133+ "input_is_dynamic" : False ,
134+ "dtype" : torchtrt .dtype .half ,
135+ "format" : torchtrt .TensorFormat .contiguous ,
136+ "explicit_set_dtype" : True
137+ }
138+
139+ i = torchtrt .Input (shape , dtype = torchtrt .dtype .half )
140+ self .assertTrue (self ._verify_correctness (i , target ))
141+
142+ i = torchtrt .Input (shape , dtype = torch .half )
143+ self .assertTrue (self ._verify_correctness (i , target ))
144+
145+ def test_tensor_format (self ):
146+ shape = [1 , 3 , 255 , 255 ]
147+ target = {
148+ "min" : shape ,
149+ "opt" : shape ,
150+ "max" : shape ,
151+ "input_is_dynamic" : False ,
152+ "dtype" : torchtrt .dtype .unknown ,
153+ "format" : torchtrt .TensorFormat .channels_last ,
154+ "explicit_set_dtype" : False
155+ }
156+
157+ i = torchtrt .Input (shape , format = torchtrt .TensorFormat .channels_last )
158+ self .assertTrue (self ._verify_correctness (i , target ))
159+
160+ i = torchtrt .Input (shape , format = torch .channels_last )
161+ self .assertTrue (self ._verify_correctness (i , target ))
162+
163+ def test_dynamic_shape (self ):
164+ min_shape = [1 , 3 , 128 , 128 ]
165+ opt_shape = [1 , 3 , 256 , 256 ]
166+ max_shape = [1 , 3 , 512 , 512 ]
167+ target = {
168+ "min" : min_shape ,
169+ "opt" : opt_shape ,
170+ "max" : max_shape ,
171+ "input_is_dynamic" : True ,
172+ "dtype" : torchtrt .dtype .unknown ,
173+ "format" : torchtrt .TensorFormat .contiguous ,
174+ "explicit_set_dtype" : False
175+ }
176+
177+ i = torchtrt .Input (min_shape = min_shape , opt_shape = opt_shape , max_shape = max_shape )
178+ self .assertTrue (self ._verify_correctness (i , target ))
179+
180+ i = torchtrt .Input (min_shape = tuple (min_shape ), opt_shape = tuple (opt_shape ), max_shape = tuple (max_shape ))
181+ self .assertTrue (self ._verify_correctness (i , target ))
182+
183+ tensor_shape = lambda shape : torch .randn (shape ).shape
184+ i = torchtrt .Input (min_shape = tensor_shape (min_shape ),
185+ opt_shape = tensor_shape (opt_shape ),
186+ max_shape = tensor_shape (max_shape ))
187+ self .assertTrue (self ._verify_correctness (i , target ))
188+
189+ if __name__ == "__main__" :
190+ unittest .main ()
0 commit comments