@@ -1,17 +1,60 @@
+#include "core/conversion/evaluators/eval_util.h"
 #include <ATen/ATen.h>
 #include "ATen/InitialTensorOptions.h"
 #include "ATen/core/List.h"
 #include "ATen/core/functional.h"
 #include "ATen/core/ivalue.h"
 #include "ATen/core/jit_type.h"
 #include "c10/util/irange.h"
 #include "core/util/prelude.h"
+#include "torch/torch.h"
 
 namespace torch_tensorrt {
 namespace core {
 namespace conversion {
 namespace evaluators {
 
+nvinfer1::ITensor* index_layer(
+    ConversionCtx* ctx,
+    const torch::jit::Node* n,
+    nvinfer1::ITensor* input_tensor,
+    int64_t index) {
+  // The index used to access the input tensor needs to be an at::Tensor
+  at::Tensor indices = torch::tensor({index}).to(torch::kI32);
+  auto indices_out = converters::tensor_to_const(ctx, indices);
+
+  auto gather_layer = ctx->net->addGather(*input_tensor, *indices_out, 0);
+  TORCHTRT_CHECK(gather_layer, "Unable to create gather layer from node: " << *n);
+  auto indexed_tensor = gather_layer->getOutput(0);
+  return indexed_tensor;
+}
+
+c10::IValue dynamic_size_layer(ConversionCtx* ctx, const torch::jit::Node* n, kwargs& args) {
+  LOG_DEBUG("Using dynamic version of aten::size evaluator");
+  auto in = args.at(n->input(0)).ITensorOrFreeze(ctx);
+  LOG_DEBUG("Input dimensions: " << in->getDimensions());
+  auto shape_layer = ctx->net->addShape(*in);
+  TORCHTRT_CHECK(shape_layer, "Unable to create shape layer from node: " << *n);
+  auto shape_1d_tensor = shape_layer->getOutput(0);
+
+  if (n->inputs().size() != 1) {
+    auto maxDim = static_cast<int64_t>(in->getDimensions().nbDims);
+    auto dim = args.at(n->input(1)).unwrapToInt();
+    // Handle negative axis by referring to nbDims of the input tensor
+    dim = dim < 0 ? dim + maxDim : dim;
+    LOG_DEBUG("Dimension to select: " << dim);
+    shape_1d_tensor = index_layer(ctx, n, shape_1d_tensor, dim);
+  }
+
+  LOG_DEBUG("Output tensor shape: " << shape_1d_tensor->getDimensions());
+
+  auto tensor_holder = TensorContainer();
+  tensor_holder.hold_tensor(shape_1d_tensor);
+  auto shape_1d_ivalue = c10::IValue(std::move(c10::make_intrusive<TensorContainer>(tensor_holder)));
+
+  return shape_1d_ivalue;
+}
+
 int64_t normalizeIndex(int64_t idx, int64_t list_size) {
   if (idx < 0) {
     // Handle negative indexing
@@ -128,7 +171,7 @@ void checkSequenceSize(int64_t n, int64_t dim, int64_t seq_size) {
 }
 
 // TODO: Conditionally enable truncation based on user setting
-at::Tensor scalar_to_tensor(const at::Scalar& s, const at::Device device = at::kCPU) {
+at::Tensor scalar_to_tensor(const at::Scalar& s, const at::Device device) {
   // This function is basically the same as the one in
   // https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/ScalarOps.h; the difference is that Int and Float
   // won't be upgraded to kDouble or kLong since we don't support those two types in conversion