@@ -19,15 +19,23 @@ namespace conversion {
1919namespace evaluators {
2020namespace {
2121
22- nvinfer1::ITensor* index_layer (){
23-
22+ nvinfer1::ITensor* index_layer (ConversionCtx* ctx, const torch::jit::Node* n, nvinfer1::ITensor* input_tensor, int64_t index){
23+ // index to access needs to be an at::Tensor
24+ at::Tensor indices = torch::tensor ({index}).to (torch::kI32 );
25+ auto indices_out = torch_tensorrt::core::conversion::converters::tensor_to_const (ctx, indices);
26+
27+ auto gather_layer = ctx->net ->addGather (*input_tensor, *indices_out, 0 );
28+ TORCHTRT_CHECK (gather_layer, " Unable to create gather layer from node: " << *n);
29+ auto indexed_tensor = gather_layer->getOutput (0 );
30+ return indexed_tensor;
2431}
2532
2633c10::IValue dynamic_size_layer (ConversionCtx* ctx, const torch::jit::Node* n, kwargs& args){
2734 LOG_DEBUG (" Using dynamic version of aten::size evaluator" );
2835 auto in = args.at (n->input (0 )).ITensorOrFreeze (ctx);
2936 LOG_DEBUG (" Input dimensions: " << in->getDimensions ());
3037 auto shape_layer = ctx->net ->addShape (*in);
38+ TORCHTRT_CHECK (shape_layer, " Unable to create shape layer from node: " << *n);
3139 auto shape_1d_tensor = shape_layer->getOutput (0 );
3240
3341 if (n->inputs ().size () != 1 ){
@@ -36,15 +44,9 @@ c10::IValue dynamic_size_layer(ConversionCtx* ctx, const torch::jit::Node* n, kw
3644 // Handle negative axis by refering to nbDims of input Tensor
3745 dim = dim < 0 ? dim + maxDim : dim;
3846 LOG_DEBUG (" Dimension to select: " << dim);
39-
40- // index to access needs to be an at::Tensor
41- at::Tensor indices = torch::tensor ({dim}).to (torch::kI32 );
42- auto indices_out = torch_tensorrt::core::conversion::converters::tensor_to_const (ctx, indices);
43-
44- auto gather_layer = ctx->net ->addGather (*shape_1d_tensor, *indices_out, 0 );
45- shape_1d_tensor = gather_layer->getOutput (0 );
47+ shape_1d_tensor = index_layer (ctx, n, shape_1d_tensor, dim);
4648 }
47-
49+
4850 LOG_DEBUG (" Output tensor shape: " << shape_1d_tensor->getDimensions ());
4951
5052 auto tensor_holder = TensorContainer ();
@@ -364,13 +366,13 @@ auto aten_registrations TORCHTRT_UNUSED =
364366 TORCHTRT_CHECK (
365367 normalized_idx >= 0 || normalized_idx < list_size, " List index out of range (aten::__getitem__)" );
366368 return list.get (normalized_idx);
367- } elif (list_input.isITensor ()){
368- return dynamic_size_layer (ctx, n, args);
369+ } else if (list_input.isITensor ()){
370+ auto indexed_tensor = index_layer (ctx, n, list_input.ITensorOrFreeze (ctx), idx);
371+ auto tensor_holder = TensorContainer ();
372+ tensor_holder.hold_tensor (indexed_tensor);
373+ auto indexed_ivalue = c10::IValue (std::move (c10::make_intrusive<TensorContainer>(tensor_holder)));
374+ return indexed_ivalue;
369375 }
370-
371-
372-
373-
374376 },
375377 EvalOptions ().validSchemas ({
376378 " aten::__getitem__.t(t[](a) list, int idx) -> (t(*))" ,
0 commit comments