from typing import Optional, Sequence, Tuple, Union

import numpy as np
import tensorrt as trt
from torch.fx.node import Target
from torch_tensorrt.dynamo._SourceIR import SourceIR
from torch_tensorrt.dynamo.conversion import impl
from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
from torch_tensorrt.dynamo.conversion.converter_utils import (
    get_trt_tensor,
    set_layer_name,
)
from torch_tensorrt.dynamo.conversion.impl.shape import get_shape_with_dynamic_shape
from torch_tensorrt.dynamo.types import TRTTensor

1416"""
1517Note: IPaddingLayer is deprecated in TensorRT 8.2 and will be removed in TensorRT 10.0.
1820"""
1921
2022
def get_padded_shape_tensors(
    ctx: ConversionContext,
    target: Union[Target, str],
    source_ir: Optional[SourceIR],
    name: str,
    input: TRTTensor,
    pad: Sequence[int],
) -> Tuple[TRTTensor, TRTTensor]:
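    """Compute slice start indices and the padded output shape as ITensors.

    Building these as ITensors rather than static Python lists lets the pad
    converters below handle inputs whose shapes are dynamic.
    """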
    rank = len(input.shape)
    if len(pad) // 2 > rank:
        raise RuntimeError(
            f"Trying to pad last {len(pad) // 2} dimensions but the input only has {rank} dimensions."
        )

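    # Materialize the (possibly dynamic) input shape as a shape tensor.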
    input_shape_tensor = get_shape_with_dynamic_shape(
        ctx,
        target,
        source_ir,
        name + "_input_shape",
        input.shape,
        input,
    )
    padded_shape_tensor = input_shape_tensor

    start_list = [0] * rank
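    # `pad` is ordered (last_dim_before, last_dim_after, second_to_last_before, ...),
    # so entry i updates dimension rank - (i + 1).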
    for i in range(len(pad) // 2):
        dim_index = rank - (i + 1)
        pad_before = pad[i * 2]
        pad_after = pad[i * 2 + 1]

        pad_sum = get_trt_tensor(
            ctx, pad_before + pad_after, f"{name}_pad_sum_{i}", dtype=np.int32
        )
        dim_shape = ctx.net.add_slice(
            input_shape_tensor,
            start=(dim_index,),
            shape=(1,),
            stride=(1,),
        ).get_output(0)

        new_dim_shape = impl.elementwise.add(
            ctx, target, source_ir, f"{name}_shape_dim_{i}", dim_shape, pad_sum
        )
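        # A negative start makes the slice read "before" the input's first
        # element; the layer's sample mode decides how those reads are filled.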
        start_list[dim_index] = -pad_before

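        # Rebuild the running shape tensor with this dimension's extent replaced.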
        slices = []
        for j in range(rank):
            if j == dim_index:
                slices.append(new_dim_shape)
            else:
                slices.append(
                    ctx.net.add_slice(
                        padded_shape_tensor,
                        start=(j,),
                        shape=(1,),
                        stride=(1,),
                    ).get_output(0)
                )
        padded_shape_tensor = impl.cat.cat(
            ctx, target, source_ir, f"{name}_cat_{i}", slices, 0
        )

    start_indices_tensor = get_trt_tensor(
        ctx,
        np.array(start_list, dtype=np.int32),
        f"{name}_start_indices_tensor",
        dtype=np.int32,
    )

    return start_indices_tensor, padded_shape_tensor


def constant_padNd(
    ctx: ConversionContext,
    target: Union[Target, str],
    source_ir: Optional[SourceIR],
    name: str,
    input: TRTTensor,
    pad: Sequence[int],
    value: Union[int, float] = 0,
) -> TRTTensor:
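    """Pad input with a constant value using an ISliceLayer in FILL mode."""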
    rank = len(input.shape)

    start_indices_tensor, padded_shape_tensor = get_padded_shape_tensors(
        ctx, target, source_ir, name, input, pad
    )

    stride_list = [1] * rank
    stride_tensor = get_trt_tensor(
        ctx,
        np.array(stride_list, dtype=np.int32),
        f"{name}_stride_tensor",
        dtype=np.int32,
    )

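    # The static start/shape/stride arguments are placeholders; the real values
    # are wired in as dynamic tensor inputs below (slot 1 = start, 2 = shape,
    # 3 = stride; slot 4 carries the fill value used by FILL mode).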
    layer = ctx.net.add_slice(
        input, start=trt.Dims(), shape=trt.Dims(), stride=trt.Dims()
    )
    layer.set_input(1, start_indices_tensor)
    layer.set_input(2, padded_shape_tensor)
    layer.set_input(3, stride_tensor)

    value_const = get_trt_tensor(ctx, value, f"{name}_value", input.dtype)
    layer.set_input(4, value_const)
    layer.mode = trt.SampleMode.FILL

    set_layer_name(layer, target, name, source_ir)
    return layer.get_output(0)


def reflection_padNd(
    ctx: ConversionContext,
    target: Union[Target, str],
    source_ir: Optional[SourceIR],
    name: str,
    input: TRTTensor,
    padding: Sequence[int],
) -> TRTTensor:
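    """Reflection-pad input using an ISliceLayer in REFLECT mode."""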
    rank = len(input.shape)

    start_indices_tensor, padded_shape_tensor = get_padded_shape_tensors(
        ctx, target, source_ir, name, input, padding
    )

    stride_list = [1] * rank
    stride_tensor = get_trt_tensor(
        ctx,
        np.array(stride_list, dtype=np.int32),
        f"{name}_stride_tensor",
        dtype=np.int32,
    )

    layer = ctx.net.add_slice(
        input, start=trt.Dims(), shape=trt.Dims(), stride=trt.Dims()
    )
    layer.set_input(1, start_indices_tensor)
    layer.set_input(2, padded_shape_tensor)
    layer.set_input(3, stride_tensor)
    layer.mode = trt.SampleMode.REFLECT

    set_layer_name(layer, target, name, source_ir)
    return layer.get_output(0)


def replication_padNd(
    ctx: ConversionContext,
    target: Union[Target, str],
    source_ir: Optional[SourceIR],
    name: str,
    input: TRTTensor,
    padding: Sequence[int],
) -> TRTTensor:
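    """Replication-pad input (edge padding) using an ISliceLayer in CLAMP mode."""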
    rank = len(input.shape)

    start_indices_tensor, padded_shape_tensor = get_padded_shape_tensors(
        ctx, target, source_ir, name, input, padding
    )

    stride_list = [1] * rank
    stride_tensor = get_trt_tensor(
        ctx,
        np.array(stride_list, dtype=np.int32),
        f"{name}_stride_tensor",
        dtype=np.int32,
    )

    layer = ctx.net.add_slice(
        input, start=trt.Dims(), shape=trt.Dims(), stride=trt.Dims()
    )
    layer.set_input(1, start_indices_tensor)
    layer.set_input(2, padded_shape_tensor)
    layer.set_input(3, stride_tensor)
    layer.mode = trt.SampleMode.CLAMP

    set_layer_name(layer, target, name, source_ir)
    return layer.get_output(0)


def circular_padNd(
    ctx: ConversionContext,
    target: Union[Target, str],
    source_ir: Optional[SourceIR],
    name: str,
    input: TRTTensor,
    padding: Sequence[int],
) -> TRTTensor:
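    """Circular-pad input (wrap-around padding) using an ISliceLayer in WRAP mode."""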
    rank = len(input.shape)

    start_indices_tensor, padded_shape_tensor = get_padded_shape_tensors(
        ctx, target, source_ir, name, input, padding
    )

    stride_list = [1] * rank
    stride_tensor = get_trt_tensor(
        ctx,
        np.array(stride_list, dtype=np.int32),
        f"{name}_stride_tensor",
        dtype=np.int32,
    )

    layer = ctx.net.add_slice(
        input, start=trt.Dims(), shape=trt.Dims(), stride=trt.Dims()
    )
    layer.set_input(1, start_indices_tensor)
    layer.set_input(2, padded_shape_tensor)
    layer.set_input(3, stride_tensor)
    layer.mode = trt.SampleMode.WRAP

    set_layer_name(layer, target, name, source_ir)
    return layer.get_output(0)