@@ -506,8 +506,17 @@ def acc_ops_size(
     kwargs: Dict[str, Argument],
     name: str,
 ) -> Union[TRTTensor, Sequence[TRTTensor]]:
-    input_val = kwargs["input"]
-
+    input_t = kwargs["input"]
+    if type(input_t) == torch.nn.Parameter or type(input_t) == torch.Tensor:
+        if (
+            not has_dynamic_shape(input_t.shape)
+            and network.has_implicit_batch_dimension
+        ):
+            return torch.Size((IMPLICIT_BATCH_DIM,) + tuple(input_t.shape))
+        return input_t.shape
+
+    # input_val = get_trt_tensor(network, input_t, f"{name}_input_t")
+    input_val = input_t
     if not isinstance(input_val, TRTTensor):
         raise RuntimeError(
             f"size received input {input_val} that is not part "
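Side note on the new fast path above: when acc_ops_size sees a concrete torch.Tensor or torch.nn.Parameter rather than a TRT tensor, it can answer from the Python shape directly. A minimal sketch of that behavior, assuming IMPLICIT_BATCH_DIM is the -1 sentinel fx2trt uses for the implicit batch axis:

import torch

IMPLICIT_BATCH_DIM = -1  # assumed sentinel for the implicit batch axis

def size_of_concrete_tensor(t: torch.Tensor, implicit_batch: bool) -> torch.Size:
    # With an implicit batch dimension the batch size is unknown at
    # conversion time, so it is prepended as the sentinel value.
    if implicit_batch:
        return torch.Size((IMPLICIT_BATCH_DIM,) + tuple(t.shape))
    return t.shape

assert size_of_concrete_tensor(torch.randn(3, 4), True) == torch.Size([-1, 3, 4])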
@@ -779,13 +788,8 @@ def acc_ops_tile(
     kwargs: Dict[str, Argument],
     name: str,
 ) -> Union[TRTTensor, Sequence[TRTTensor]]:
-    input_val = kwargs["input"]
-
-    if not isinstance(input_val, TRTTensor):
-        raise RuntimeError(
-            f"tile received input {input_val} that is not part "
-            "of the TensorRT region!"
-        )
+    input_t = kwargs["input"]
+    input_val = get_trt_tensor(network, input_t, f"{name}_input")
 
     dims = tuple(cast(Sequence[int], kwargs["dims"]))
     n_input_dims = len(input_val.shape) + (
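Note that acc_ops_tile now funnels its input through get_trt_tensor instead of rejecting non-TRT values outright. A rough sketch of what that helper is assumed to do (pass ITensors through, materialize Python/NumPy constants via network.add_constant):

import numpy as np
import tensorrt as trt

def get_trt_tensor_sketch(network: trt.INetworkDefinition, value, name: str):
    # Sketch only: the real fx2trt get_trt_tensor also handles dtype and
    # shape edge cases; this shows just the pass-through/constant split.
    if isinstance(value, trt.ITensor):
        return value  # already part of the TensorRT region
    arr = np.ascontiguousarray(value, dtype=np.float32)
    layer = network.add_constant(arr.shape, arr)  # freeze value as a constant layer
    layer.name = name
    return layer.get_output(0)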
@@ -822,9 +826,28 @@ def acc_ops_tile(
     if network.has_implicit_batch_dimension:
         assert dims[0] == 1, "Can't tile the batch dim when it's implicit."
         dims = dims[1:]
-
     starts = [0] * len(dims)
-    shapes = [i * j for i, j in zip(input_val.shape, dims)]  # type: ignore[union-attr]
+    shapes = []
+    if all(isinstance(d, int) for d in dims):
+        shapes = [i * j for i, j in zip(input_val.shape, dims)]  # type: ignore[union-attr]
+    else:
+        shape = []
+        for i, (s, d) in enumerate(zip(input_val.shape, dims)):
+            if isinstance(d, TRTTensor) and len(d.shape) == 0:
+                d = prepend_ones(network, d, f"{name}_{i}", 1)
+            else:
+                d = get_trt_tensor(network, d, f"{name}_{i}")
+            shape.append(d)
+            mul = add_binary_elementwise_layer(
+                network,
+                s,
+                d,
+                trt.ElementWiseOperation.PROD,
+                target,
+                f"{name}_mul_{i}",
+            )
+            shapes.append(mul)
+        dims = shape
     # If there's a dynamic dim then there would be negative dims in shapes, which is not allowed.
     # Here we build a dummy shapes array.
     if has_dynamic_shape(input_val.shape):  # type: ignore[union-attr]
@@ -838,9 +861,16 @@ def acc_ops_tile(
         starts_tensor = network.add_constant(
             (len(dims),), np.ascontiguousarray([0] * len(dims), np.int32)
         ).get_output(0)
-        dims_tensor = network.add_constant(
-            (len(dims),), np.ascontiguousarray(dims, np.int32)
-        ).get_output(0)
+        if all(isinstance(d, int) for d in dims):
+            dims_tensor = network.add_constant(
+                (len(dims),), np.ascontiguousarray(dims, np.int32)
+            ).get_output(0)
+        else:
+            assert all(isinstance(d, TRTTensor) for d in dims)
+            concat_dims_layer = network.add_concatenation(inputs=dims)
+            concat_dims_layer.axis = 0
+            concat_dims_layer.name = f"{name}_tile_dim"
+            dims_tensor = concat_dims_layer.get_output(0)
         input_shape_layer = network.add_shape(input_val)
         input_shape_layer.name = f"{name}_slice_input_shape"
         slice_shapes_tensor = add_binary_elementwise_layer(
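The two tile hunks above split the static and dynamic cases: with plain ints the output extents are simple products, while TRT scalars go through an elementwise PROD layer and are then concatenated into a 1-D shape tensor for the slice layer. The static arithmetic, for reference:

# Static-dims case of acc_ops_tile: each output extent is input extent * repeat.
input_shape = (2, 3)
dims = (2, 2)
shapes = [i * j for i, j in zip(input_shape, dims)]
assert shapes == [4, 6]  # matches torch.randn(2, 3).tile(2, 2).shape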
@@ -1880,7 +1910,8 @@ def acc_ops_max_pool1d(
 
 
 @tensorrt_converter(acc_ops.max_pool2d)
-def acc_ops_max_pool2d(
+@tensorrt_converter(acc_ops.max_pool3d)
+def acc_ops_max_poolnd(
     network: TRTNetwork,
     target: Target,
     args: Tuple[Argument, ...],
@@ -1894,26 +1925,27 @@ def acc_ops_max_pool2d(
             f"MaxPool2d received input {input_val} that is not part "
             "of the TensorRT region!"
         )
-
-    kernel_size = extend_attr_to_tuple(kwargs["kernel_size"], 2)
-    stride = extend_attr_to_tuple(kwargs["stride"], 2)
-    padding = extend_attr_to_tuple(kwargs["padding"], 2)
-    dilation = extend_attr_to_tuple(kwargs["dilation"], 2)
+    extend_len = 2 if target == acc_ops.max_pool2d else 3
+    kernel_size = extend_attr_to_tuple(kwargs["kernel_size"], extend_len)
+    stride = extend_attr_to_tuple(kwargs["stride"], extend_len)
+    padding = extend_attr_to_tuple(kwargs["padding"], extend_len)
+    dilation = extend_attr_to_tuple(kwargs["dilation"], extend_len)
     ceil_mode = kwargs["ceil_mode"]
 
     if len(stride) == 0 or stride[0] == None:
         stride = kernel_size
 
-    if dilation != (1, 1):
+    ones = (1,) * extend_len
+    if dilation != ones:
        raise RuntimeError(
            f"Only support dilation=(1, 1) for maxpool, but got {dilation}"
        )
 
-    layer = network.add_pooling(
+    layer = network.add_pooling_nd(
         input=input_val, type=trt.PoolingType.MAX, window_size=kernel_size
     )
-    layer.stride = stride
-    layer.padding = padding
+    layer.stride_nd = stride
+    layer.padding_nd = padding
     set_layer_name(layer, target, name)
 
     if ceil_mode:
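Merging max_pool2d and max_pool3d into one converter works because the only 2-D/3-D difference is the tuple length that parameters are normalized to. A sketch of the assumed extend_attr_to_tuple semantics:

def extend_attr_to_tuple_sketch(val, num_elem: int) -> tuple:
    # Broadcast a scalar attribute to an n-tuple; pass sequences through.
    if not isinstance(val, (tuple, list)):
        val = (val,) * num_elem
    return tuple(val)

assert extend_attr_to_tuple_sketch(3, 2) == (3, 3)      # max_pool2d
assert extend_attr_to_tuple_sketch(3, 3) == (3, 3, 3)   # max_pool3d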
@@ -2093,8 +2125,8 @@ def acc_ops_unsqueeze(
     kwargs: Dict[str, Argument],
     name: str,
 ) -> Union[TRTTensor, Sequence[TRTTensor]]:
-    input_val = kwargs["input"]
-
+    input_t = kwargs["input"]
+    input_val = get_trt_tensor(network, input_t, f"{name}_input_t")
     if not isinstance(input_val, TRTTensor):
         raise RuntimeError(
             f"unsqueeze received input {input_val} that is not part "
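With the input routed through get_trt_tensor, acc_ops_unsqueeze now also accepts constants that live outside the TensorRT region, and the isinstance check below it becomes a safety net. The op it converts follows torch.unsqueeze semantics:

import torch

# unsqueeze inserts a size-1 dimension at the given index.
t = torch.randn(3, 4)
assert t.unsqueeze(0).shape == (1, 3, 4)
assert t.unsqueeze(-1).shape == (3, 4, 1)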
@@ -2161,8 +2193,9 @@ def acc_ops_topk(
     return layer.get_output(0), layer.get_output(1)
 
 
+@tensorrt_converter(acc_ops.adaptive_avg_pool3d)
 @tensorrt_converter(acc_ops.adaptive_avg_pool2d)
-def acc_ops_adaptive_avg_pool2d(
+def acc_ops_adaptive_avg_poolnd(
     network: TRTNetwork,
     target: Target,
     args: Tuple[Argument, ...],
@@ -2177,30 +2210,32 @@ def acc_ops_adaptive_avg_pool2d(
             "of the TensorRT region!"
         )
 
-    assert (
-        input_val.shape[-1] != -1 and input_val.shape[-1] != -1
+    extend_len = 2 if target == acc_ops.adaptive_avg_pool2d else 3
+    assert all(
+        input_val.shape[-(i + 1)] != -1 for i in range(extend_len)
     ), "AdaptiveAvgPool2d currently doesn't support dynamic shapes for last two dims."
 
-    output_size = cast(Sequence[int], extend_attr_to_tuple(kwargs["output_size"], 2))
-    for input_dim, output_dim in zip(input_val.shape[-2:], output_size):
+    output_size = cast(
+        Sequence[int], extend_attr_to_tuple(kwargs["output_size"], extend_len)
+    )
+    for input_dim, output_dim in zip(input_val.shape[-extend_len:], output_size):
         if input_dim % output_dim != 0:
             raise RuntimeError(
                 "For AdaptiveAvgPool, input dim has to be integer multiple of output dim."
                 f"Got input dim {input_dim}, output dim {output_dim}"
             )
 
-    stride = (
-        input_val.shape[-2] // output_size[0],
-        input_val.shape[-1] // output_size[1],
+    stride = tuple(
+        input_val.shape[-extend_len + i] // output_size[i] for i in range(extend_len)
     )
-    kernel_size = (
-        input_val.shape[-2] - (output_size[0] - 1) * stride[0],
-        input_val.shape[-1] - (output_size[1] - 1) * stride[1],
+    kernel_size = tuple(
+        input_val.shape[-extend_len + i] - (output_size[i] - 1) * stride[i]
+        for i in range(extend_len)
     )
-    layer = network.add_pooling(
+    layer = network.add_pooling_nd(
         input=input_val, type=trt.PoolingType.AVERAGE, window_size=kernel_size
     )
-    layer.stride = stride
+    layer.stride_nd = stride
     set_layer_name(layer, target, name)
 
     return layer.get_output(0)
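The generalized stride/kernel computation above reduces to the familiar adaptive-pool arithmetic; a worked 2-D check:

# For a (N, C, 6, 6) input with output_size=(2, 2):
# stride = in // out, kernel = in - (out - 1) * stride.
in_hw, out_hw = (6, 6), (2, 2)
stride = tuple(i // o for i, o in zip(in_hw, out_hw))
kernel = tuple(i - (o - 1) * s for i, o, s in zip(in_hw, out_hw, stride))
assert stride == (3, 3) and kernel == (3, 3)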
@@ -2781,7 +2816,6 @@ def acc_ops_getitem(
 ) -> Union[TRTTensor, Sequence[TRTTensor]]:
     input_val = kwargs["input"]
     slices = kwargs["idx"]
-
     if not isinstance(input_val, TRTTensor):
         return operator.getitem(input_val, slices)  # type: ignore[arg-type]
 