
import numpy as np
import tensorrt as trt
+import torch
from torch.fx.node import Target
from torch_tensorrt.dynamo._SourceIR import SourceIR
from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
@@ -80,23 +81,34 @@ def index(
    source_ir: Optional[SourceIR],
    name: str,
    input: TRTTensor,
-    index: Union[TRTTensor, Sequence[TRTTensor]],
+    index: Sequence[Union[TRTTensor, np.ndarray, torch.Tensor]],
) -> TRTTensor:
    adv_indx_indices = []
    tensor_indices = []
-    # _LOGGER.debug(f"The index shape is {index.shape}")
    # check if the input is dynamic
    dynamic_shape = has_dynamic_shape(input.shape)
-
+    # is_numpy is a flag to specify if all the indices are numpy arrays or torch tensors.
+    # If any index is not, this flag will be set to False.
+    _LOGGER.debug(
+        f"Determining whether aten.index constant-index optimization can be invoked"
+    )
+    is_numpy = all(
+        isinstance(ind, (torch.Tensor, np.ndarray)) for ind in index if ind is not None
+    )
    # here we need to check if all the indices are broadcastable
    # if not, then we need to broadcast
    last_index = None
    for i, ind in enumerate(index):
        if ind is not None:
            _LOGGER.debug(f"Shape of {i} index is {ind.shape}")
            adv_indx_indices.append(i)
-            # torch.nn.parameter.Parameter=> torch.Tensor
-            ind = get_trt_tensor(ctx, ind, name + f"_parameter_to_fp32_tensor_{i}")
+            # torch.nn.parameter.Parameter => numpy array
+            # numpy arrays are kept as numpy arrays
+            # other cases are converted to TRTTensor
+            if is_numpy:
+                ind = to_numpy(ind)
+            else:
+                ind = get_trt_tensor(ctx, ind, name + f"_parameter_to_fp32_tensor_{i}")
            if last_index is not None:
                assert broadcastable(
                    ind, last_index
@@ -110,8 +122,9 @@ def index(
        set_layer_name(identity_layer, target, name + "_index_identity", source_ir)
        return identity_layer.get_output(0)
    elif len(tensor_indices) == 1:
-        # This case works
-        indices_tensor = tensor_indices[0]
+        indices_tensor = get_trt_tensor(
+            ctx, tensor_indices[0], name + f"_parameter_to_fp32_tensor"
+        )
        index = adv_indx_indices[0]
        _LOGGER.debug(f"The advanced index indices is {adv_indx_indices}")
        gather_layer = ctx.net.add_gather(input, indices_tensor, index)
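
For reference, the single-advanced-index branch above reduces to a plain gather along that dimension; `ctx.net.add_gather` is the TensorRT expression of it. A minimal standalone NumPy sketch, with invented shapes and index values (illustrative only, not converter code):

import numpy as np

# A single advanced index along one dimension is just a gather on that axis.
x = np.arange(24).reshape(2, 3, 4)
idx = np.array([2, 0, 1])
assert np.array_equal(x[:, idx, :], np.take(x, idx, axis=1))
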
@@ -150,6 +163,7 @@ def index(
            if i not in adv_indx_indices:
                new_order.append(i)
        _LOGGER.debug(f"The new transpose order is {new_order}")
+
        transpose_layer.second_transpose = tuple(new_order)
        set_layer_name(transpose_layer, target, name + "_index_transpose", source_ir)
        transpose_tensor = transpose_layer.get_output(0)
@@ -175,47 +189,58 @@ def index(
        concat_tensor = concat_tensor_layer.get_output(0)

        reshape_layer = ctx.net.add_shuffle(transpose_tensor)
-        # check this
        reshape_layer.set_input(1, concat_tensor)
        flatten_tensor = reshape_layer.get_output(0)
+
        _LOGGER.debug(f"The flatten tensor shape is {flatten_tensor.shape}")

        # tensor index = \sum_{i=1}^m (ind_i * \prod_{j=i+1}^m x_j), where ind_i is
        # input indices[i] and x_j is the j-th dimension of input x.
-        multiplier = get_trt_tensor(
-            ctx,
-            dim_tensor_list[adv_indx_indices[adv_indx_count - 1]],
-            name + "_dim_last",
-        )
-        cum_adv_index = tensor_indices[adv_indx_count - 1]
-        for i in range(adv_indx_count - 2, -1, -1):
-            adv_index = convert_binary_elementwise(
-                ctx,
-                target,
-                source_ir,
-                name + f"_index_intermediate_{i}",
-                trt.ElementWiseOperation.PROD,
-                multiplier,
-                tensor_indices[i],
+        if is_numpy:
+            multiplier = input_shape[adv_indx_indices[adv_indx_count - 1]]
+            cum_adv_index = tensor_indices[adv_indx_count - 1]
+            for i in range(adv_indx_count - 2, -1, -1):
+                adv_index = multiplier * tensor_indices[i]
+                cum_adv_index = cum_adv_index + adv_index
+                multiplier = multiplier * input_shape[adv_indx_indices[i]]
+            cum_adv_index = get_trt_tensor(
+                ctx, cum_adv_index, name + f"_index_sum_intermediate"
            )
-            cum_adv_index = convert_binary_elementwise(
-                ctx,
-                target,
-                source_ir,
-                name + f"_index_sum_intermediate_{i}",
-                trt.ElementWiseOperation.SUM,
-                cum_adv_index,
-                adv_index,
-            )
-            multiplier = convert_binary_elementwise(
+        else:
+            multiplier = get_trt_tensor(
                ctx,
-                target,
-                source_ir,
-                name + f"_index_intermediate_xj_{i}",
-                trt.ElementWiseOperation.PROD,
-                multiplier,
-                dim_tensor_list[adv_indx_indices[i]],
+                dim_tensor_list[adv_indx_indices[adv_indx_count - 1]],
+                name + "_dim_last",
            )
+            cum_adv_index = tensor_indices[adv_indx_count - 1]
+            for i in range(adv_indx_count - 2, -1, -1):
+                adv_index = convert_binary_elementwise(
+                    ctx,
+                    target,
+                    source_ir,
+                    name + f"_index_intermediate_{i}",
+                    trt.ElementWiseOperation.PROD,
+                    multiplier,
+                    tensor_indices[i],
+                )
+                cum_adv_index = convert_binary_elementwise(
+                    ctx,
+                    target,
+                    source_ir,
+                    name + f"_index_sum_intermediate_{i}",
+                    trt.ElementWiseOperation.SUM,
+                    cum_adv_index,
+                    adv_index,
+                )
+                multiplier = convert_binary_elementwise(
+                    ctx,
+                    target,
+                    source_ir,
+                    name + f"_index_intermediate_xj_{i}",
+                    trt.ElementWiseOperation.PROD,
+                    multiplier,
+                    dim_tensor_list[adv_indx_indices[i]],
+                )

        gather_layer_element = ctx.net.add_gather(flatten_tensor, cum_adv_index, 0)
        set_layer_name(
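
To make the flattened-offset formula concrete, here is a minimal standalone NumPy sketch of what the constant-index (`is_numpy`) branch folds at conversion time. The shape and index values are invented for illustration, and the transpose of non-advanced dimensions that the converter performs beforehand is omitted:

import numpy as np

# Two constant advanced indices over dims 0 and 1 of x.
x = np.arange(24).reshape(2, 3, 4)
ind0 = np.array([0, 1])
ind1 = np.array([2, 0])

# Accumulate exactly as the loop above does, from the last advanced dim backwards:
# cum_adv_index = ind1 + x.shape[1] * ind0
multiplier = x.shape[1]
cum_adv_index = ind1 + multiplier * ind0

# Gathering on the (dim0 * dim1, dim2) view reproduces advanced indexing.
flat = x.reshape(-1, x.shape[2])
assert np.array_equal(flat[cum_adv_index], x[ind0, ind1])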