@@ -187,27 +187,20 @@ def aten_ops_fmod(
187187 return acc_ops_converters .acc_ops_fmod (network , target , None , kwargs_new , name )
188188
189189
190- @tensorrt_converter (torch .ops .aten .mm .default )
191- @tensorrt_converter (torch .ops .aten .addmm .default )
190+ @tensorrt_converter (torch .ops .aten .linear )
192191def aten_ops_linear (
193192 network : TRTNetwork ,
194193 target : Target ,
195194 args : Tuple [Argument , ...],
196195 kwargs : Dict [str , Argument ],
197196 name : str ,
198197) -> Union [TRTTensor , Sequence [TRTTensor ]]:
199- if target == torch .ops .aten .addmm .default :
200- kwargs_new = {
201- "bias" : args [0 ],
202- "input" : args [1 ],
203- "weight" : args [2 ],
204- }
205- elif target == torch .ops .aten .mm .default :
206- kwargs_new = {
207- "bias" : None ,
208- "input" : args [0 ],
209- "weight" : args [1 ],
210- }
198+ kwargs_new = {
199+ "input" : args [0 ],
200+ "weight" : args [1 ],
201+ "bias" : args [2 ],
202+ }
203+
211204 return acc_ops_converters .acc_ops_linear (network , target , None , kwargs_new , name )
212205
213206
@@ -320,3 +313,35 @@ def aten_ops_reshape(
320313 "acc_out_ty" : acc_utils .build_raw_tensor_meta (shape = args [1 ]),
321314 }
322315 return acc_ops_converters .acc_ops_reshape (network , target , None , kwargs_new , name )
316+
317+
318+ @tensorrt_converter (torch .ops .aten .cat .default )
319+ def aten_ops_cat (
320+ network : TRTNetwork ,
321+ target : Target ,
322+ args : Tuple [Argument , ...],
323+ kwargs : Dict [str , Argument ],
324+ name : str ,
325+ ) -> Union [TRTTensor , Sequence [TRTTensor ]]:
326+ kwargs_new = {
327+ "tensors" : args [0 ],
328+ "dim" : args [1 ],
329+ }
330+ return acc_ops_converters .acc_ops_cat (network , target , None , kwargs_new , name )
331+
332+
333+ @tensorrt_converter (torch .ops .aten .expand .default )
334+ def aten_ops_expand (
335+ network : TRTNetwork ,
336+ target : Target ,
337+ args : Tuple [Argument , ...],
338+ kwargs : Dict [str , Argument ],
339+ name : str ,
340+ ) -> Union [TRTTensor , Sequence [TRTTensor ]]:
341+ kwargs_new = {
342+ "input" : args [0 ],
343+ "sizes" : args [1 ],
344+ }
345+ return acc_ops_converters .acc_ops_expand_tensor (
346+ network , target , None , kwargs_new , name
347+ )
0 commit comments