11import torch
22from parameterized import param , parameterized
33from torch .testing ._internal .common_utils import run_tests
4+
45from torch_tensorrt import Input
56
67from .harness import DispatchTestCase
@@ -15,6 +16,21 @@ class TestDeconvolutionConverter(DispatchTestCase):
1516 param ("non_zero_padding" , 1 , padding = 1 ),
1617 param ("dilation" , 1 , dilation = 2 ),
1718 param ("groups" , 1 , groups = 3 ),
19+ param ("output_padding_1" , 3 , stride = 2 , padding = 1 , output_padding = 1 ),
20+ param ("output_padding_2" , 3 , stride = 2 , padding = 2 , output_padding = 1 ),
21+ param ("output_padding_3" , 3 , stride = 2 , padding = 3 , output_padding = 1 ),
22+ param ("output_padding_4" , 3 , stride = 3 , padding = 2 , output_padding = 1 ),
23+ param ("output_padding_5" , 3 , stride = 3 , padding = 3 , output_padding = 1 ),
24+ param ("output_padding_6" , 3 , stride = 3 , padding = 3 , output_padding = 2 ),
25+ param (
26+ "combined_params" ,
27+ 3 ,
28+ stride = 3 ,
29+ padding = 3 ,
30+ dilation = 2 ,
31+ groups = 3 ,
32+ output_padding = 2 ,
33+ ),
1834 ]
1935 )
2036 def test_deconv1d (
@@ -26,6 +42,7 @@ def test_deconv1d(
2642 dilation = 1 ,
2743 groups = 1 ,
2844 bias = True ,
45+ output_padding = 0 ,
2946 ):
3047 class TestModule (torch .nn .Module ):
3148 def __init__ (self ):
@@ -36,9 +53,10 @@ def __init__(self):
3653 kernel_size = kernel_size ,
3754 stride = stride ,
3855 padding = padding ,
39- dilation = dilation ,
56+ output_padding = output_padding ,
4057 groups = groups ,
4158 bias = bias ,
59+ dilation = dilation ,
4260 )
4361
4462 def forward (self , x ):
@@ -101,6 +119,22 @@ def forward(self, x):
101119 param ("non_zero_padding" , 1 , padding = 1 ),
102120 param ("dilation" , 1 , dilation = 2 ),
103121 param ("groups" , 1 , groups = 3 ),
122+ param ("output_padding_1" , 3 , stride = 2 , padding = 1 , output_padding = 1 ),
123+ param ("output_padding_2" , 3 , stride = 2 , padding = 1 , output_padding = 1 ),
124+ param ("output_padding_3" , 3 , stride = 2 , padding = 2 , output_padding = 1 ),
125+ param ("output_padding_4" , 3 , stride = 2 , padding = 3 , output_padding = 1 ),
126+ param ("output_padding_5" , 3 , stride = 3 , padding = 2 , output_padding = 1 ),
127+ param ("output_padding_6" , 3 , stride = 3 , padding = 3 , output_padding = 1 ),
128+ param ("output_padding_7" , 3 , stride = 3 , padding = 3 , output_padding = 2 ),
129+ param (
130+ "combined_params" ,
131+ 3 ,
132+ stride = 3 ,
133+ padding = 3 ,
134+ dilation = 2 ,
135+ groups = 3 ,
136+ output_padding = 2 ,
137+ ),
104138 ]
105139 )
106140 def test_deconv2d (
@@ -112,6 +146,7 @@ def test_deconv2d(
112146 dilation = 1 ,
113147 groups = 1 ,
114148 bias = True ,
149+ output_padding = 0 ,
115150 ):
116151 class TestModule (torch .nn .Module ):
117152 def __init__ (self ):
@@ -122,9 +157,10 @@ def __init__(self):
122157 kernel_size = kernel_size ,
123158 stride = stride ,
124159 padding = padding ,
125- dilation = dilation ,
160+ output_padding = output_padding ,
126161 groups = groups ,
127162 bias = bias ,
163+ dilation = dilation ,
128164 )
129165
130166 def forward (self , x ):
@@ -172,6 +208,19 @@ def forward(self, x):
172208 param ("non_zero_padding" , 1 , padding = 1 ),
173209 param ("dilation" , 1 , dilation = 2 ),
174210 param ("groups" , 1 , groups = 3 ),
211+ param ("output_padding_1" , 3 , stride = 2 , padding = 1 , output_padding = 1 ),
212+ param ("output_padding_2" , 3 , stride = 2 , padding = 2 , output_padding = 1 ),
213+ param ("output_padding_3" , 3 , stride = 3 , padding = 3 , output_padding = 1 ),
214+ param ("output_padding_4" , 3 , stride = 3 , padding = 3 , output_padding = 2 ),
215+ param (
216+ "combined_params" ,
217+ 3 ,
218+ stride = 3 ,
219+ padding = 3 ,
220+ dilation = 2 ,
221+ groups = 3 ,
222+ output_padding = 2 ,
223+ ),
175224 ]
176225 )
177226 def test_deconv3d (
@@ -183,6 +232,7 @@ def test_deconv3d(
183232 dilation = 1 ,
184233 groups = 1 ,
185234 bias = True ,
235+ output_padding = 0 ,
186236 ):
187237 class TestModule (torch .nn .Module ):
188238 def __init__ (self ):
@@ -193,9 +243,10 @@ def __init__(self):
193243 kernel_size = kernel_size ,
194244 stride = stride ,
195245 padding = padding ,
196- dilation = dilation ,
246+ output_padding = output_padding ,
197247 groups = groups ,
198248 bias = bias ,
249+ dilation = dilation ,
199250 )
200251
201252 def forward (self , x ):
0 commit comments