@@ -2483,7 +2483,10 @@ def acc_ops_where(
24832483
24842484 if type (x_t ) != TRTTensor :
24852485 if x_shape != output_shape :
2486- x_t .expand (output_shape )
2486+ # special case where 1 element in x_t
2487+ if len (x_t .shape ) == 0 :
2488+ x_t = x_t .unsqueeze (0 )
2489+ x_t = x_t .expand (output_shape )
24872490 x_val = get_trt_tensor (network , x_t , f"{ name } _x" )
24882491 else :
24892492 x_val = x_t
@@ -2498,7 +2501,10 @@ def acc_ops_where(
24982501
24992502 if type (y_t ) != TRTTensor :
25002503 if y_shape != output_shape :
2501- y_t .expand (output_shape )
2504+ # special case where 1 element in y_t
2505+ if len (y_t .shape ) == 0 :
2506+ y_t = y_t .unsqueeze (0 )
2507+ y_t = y_t .expand (output_shape )
25022508 y_val = get_trt_tensor (network , y_t , f"{ name } _y" )
25032509 else :
25042510 y_val = y_t
@@ -2912,16 +2918,20 @@ def acc_ops_cat(
29122918 name : str ,
29132919) -> Union [TRTTensor , Sequence [TRTTensor ]]:
29142920 tensors = kwargs ["tensors" ]
2921+ dim = kwargs ["dim" ]
29152922
29162923 if any (not isinstance (t , TRTTensor ) for t in tensors ): # type: ignore[union-attr]
29172924 raise RuntimeError (
29182925 f"cat received inputs { tensors } that is not part " "of the TensorRT region!"
29192926 )
2920-
29212927 layer = network .add_concatenation (inputs = tensors )
2922- layer .axis = cast (int , kwargs ["dim" ]) - (
2923- 1 if network .has_implicit_batch_dimension else 0
2924- )
2928+ if dim < 0 :
2929+ if network .has_implicit_batch_dimension :
2930+ dim = len (tensors [0 ].shape ) + 1 + dim
2931+ else :
2932+ dim = len (tensors [0 ].shape ) + dim
2933+
2934+ layer .axis = dim - (1 if network .has_implicit_batch_dimension else 0 )
29252935 set_layer_name (layer , target , name )
29262936 return layer .get_output (0 )
29272937
@@ -3477,3 +3487,129 @@ def acc_ops_interpolate(
34773487
34783488 set_layer_name (layer , target , name )
34793489 return layer .get_output (0 )
3490+
3491+
@tensorrt_converter(acc_ops.new_ones)
def acc_ops_new_ones(
    network: TRTNetwork,
    target: Target,
    args: Tuple[Argument, ...],
    kwargs: Dict[str, Argument],
    name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
    """Convert ``acc_ops.new_ones`` to a TensorRT constant of ones.

    The ones tensor is materialized eagerly with ``torch.ones`` and embedded
    into the network via ``get_trt_tensor``, so ``size`` must be known at
    conversion time.

    kwargs:
        input: reference tensor; its dtype is used when ``dtype`` is absent.
        size: static shape of the output tensor.
        dtype (optional): torch dtype of the output; defaults to input's dtype.
        device (optional): must be "cuda" or None — TRT networks are CUDA-only.
    """
    input_val = kwargs["input"]
    size_val = kwargs["size"]
    dtype_val = kwargs.get("dtype")
    if dtype_val is None:
        # input_val is a TRT tensor, so its dtype is a TRT dtype and must be
        # mapped back to the torch dtype expected by torch.ones.
        dtype_val = input_val.dtype
        dtype_val = torch_dtype_from_trt(dtype_val)

    device_val = kwargs.get("device")
    # NOTE: `is None` (identity), not `== None`.
    assert (
        device_val == "cuda" or device_val is None
    ), f"device is not `cuda` but {device_val}"

    weight = torch.ones(size_val, dtype=dtype_val)
    return get_trt_tensor(network, weight, f"{name}_weight")
3514+
3515+
@tensorrt_converter(acc_ops.new_empty)
def acc_ops_new_empty(
    network: TRTNetwork,
    target: Target,
    args: Tuple[Argument, ...],
    kwargs: Dict[str, Argument],
    name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
    """Convert ``acc_ops.new_empty`` to a TensorRT constant.

    ``new_empty`` only guarantees shape/dtype, not contents, so a deterministic
    zeros tensor is emitted instead of uninitialized memory. ``size`` must be
    known at conversion time.

    kwargs:
        input: reference tensor; its dtype is used when ``dtype`` is absent.
        size: static shape of the output tensor.
        dtype (optional): torch dtype of the output; defaults to input's dtype.
        device (optional): must be "cuda" or None — TRT networks are CUDA-only.
    """
    input_val = kwargs["input"]
    size_val = kwargs["size"]
    dtype_val = kwargs.get("dtype")
    if dtype_val is None:
        # input_val is a TRT tensor, so its dtype is a TRT dtype and must be
        # mapped back to the torch dtype expected by torch.zeros.
        dtype_val = input_val.dtype
        dtype_val = torch_dtype_from_trt(dtype_val)

    device_val = kwargs.get("device")
    # NOTE: `is None` (identity), not `== None`.
    assert (
        device_val == "cuda" or device_val is None
    ), f"device is not `cuda` but {device_val}"

    weight = torch.zeros(size_val, dtype=dtype_val)
    return get_trt_tensor(network, weight, f"{name}_weight")
3538+
3539+
@tensorrt_converter(acc_ops.einsum)
def acc_ops_einsum(
    network: TRTNetwork,
    target: Target,
    args: Tuple[Argument, ...],
    kwargs: Dict[str, Argument],
    name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
    """Convert ``acc_ops.einsum`` to a TensorRT ``IEinsumLayer``.

    kwargs:
        operands: sequence of inputs (TRT tensors and/or torch constants).
        equation: the einsum equation string.
    """
    input_val = list(kwargs["operands"])
    equation = kwargs["equation"]
    assert isinstance(equation, str), "equation type is not str"
    const_flag = False
    for i, input_source in enumerate(input_val):
        if isinstance(input_source, torch.Tensor):
            # const change to TRTensor always output with dtype FLOAT even
            # though stored memory is other type, so we cast to float first.
            # And we need other inputs to be the same float type.
            input_source = input_source.to(torch.float)
            const_flag = True
        input_val[i] = get_trt_tensor(network, input_source, name + f"_input_source{i}")

    if const_flag:
        # Bring every remaining operand to float32 so all einsum inputs agree.
        for i, input_source in enumerate(input_val):
            if input_source.dtype != trt.float32:
                input_val[i] = type_cast(
                    network, target, f"{name}_input_cast{i}", input_source, trt.float32
                )
    einsum_layer = network.add_einsum(inputs=input_val, equation=equation)
    # Name the layer for debuggability, consistent with the other converters.
    set_layer_name(einsum_layer, target, name)
    return einsum_layer.get_output(0)
3568+
3569+
@tensorrt_converter(acc_ops.as_strided)
def acc_ops_as_strided(
    network: TRTNetwork,
    target: Target,
    args: Tuple[Argument, ...],
    kwargs: Dict[str, Argument],
    name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
    """Convert ``acc_ops.as_strided`` via flatten -> gather -> reshape.

    Every output element's flat-buffer index is computed at conversion time as
    ``storage_offset + sum_d(stride[d] * i_d)``; a single gather then collects
    them from the flattened input. ``size`` and ``stride`` must therefore be
    static (output element count grows as prod(size)).

    kwargs:
        input: the source TRT tensor.
        size: static output shape.
        stride: per-dimension strides into the flattened input.
        storage_offset (optional): base offset into the flat buffer; default 0.
    """
    input_val = kwargs["input"]
    size = kwargs["size"]
    stride = kwargs["stride"]
    offset = kwargs.get("storage_offset")
    if offset is None:
        offset = 0

    assert len(size) == len(stride), "size and stride shapes are not the same"

    # Flatten the input to a 1d vector so one gather can pick arbitrary elements.
    new_kwargs = {"input": kwargs["input"], "start_dim": 0, "end_dim": -1}
    flatten_output = acc_ops_flatten(network, target, [], new_kwargs, name + "_flatten")

    # Enumerate flat indices in row-major order, expanding one dimension at a
    # time (iterative equivalent of recursing over the dims).
    indices = [0]
    for dim_size, dim_stride in zip(size, stride):
        indices = [base + dim_stride * i for base in indices for i in range(dim_size)]

    indices = torch.tensor(indices, dtype=torch.int) + offset
    indices_tensor = get_trt_tensor(network, indices, name + "_indices_tensor")
    gather_layer = network.add_gather(flatten_output, indices_tensor, axis=0)
    # Reshape the gathered 1d result back to the requested output shape.
    shuffle_layer = network.add_shuffle(gather_layer.get_output(0))
    set_layer_name(shuffle_layer, target, name + "_shuffle")
    shuffle_layer.reshape_dims = tuple(size)

    return shuffle_layer.get_output(0)
0 commit comments