@@ -19,6 +19,46 @@ namespace conversion {
 namespace evaluators {
 namespace {
 
+nvinfer1::ITensor* index_layer() {
+  return nullptr; // TODO: stub, not yet implemented
+}
+
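+// Builds the aten::size computation into the TensorRT network for dynamically shaped inputs:
+// an IShapeLayer emits the shape as a 1D tensor, which an IGatherLayer optionally reduces to a single dimension.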
+c10::IValue dynamic_size_layer(ConversionCtx* ctx, const torch::jit::Node* n, kwargs& args) {
+  LOG_DEBUG("Using dynamic version of aten::size evaluator");
+  auto in = args.at(n->input(0)).ITensorOrFreeze(ctx);
+  LOG_DEBUG("Input dimensions: " << in->getDimensions());
+  auto shape_layer = ctx->net->addShape(*in);
+  auto shape_1d_tensor = shape_layer->getOutput(0);
+
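+  // aten::size(Tensor self) takes a single input and returns the whole shape; aten::size.int(Tensor self, int dim) has a second input selecting one dimension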
+  if (n->inputs().size() != 1) {
+    auto maxDim = static_cast<int64_t>(in->getDimensions().nbDims);
+    auto dim = args.at(n->input(1)).unwrapToInt();
+    // Handle negative axis by referring to nbDims of input Tensor
+    dim = dim < 0 ? dim + maxDim : dim;
+    LOG_DEBUG("Dimension to select: " << dim);
+
+    // index to access needs to be an at::Tensor
+    at::Tensor indices = torch::tensor({dim}).to(torch::kI32);
+    auto indices_out = torch_tensorrt::core::conversion::converters::tensor_to_const(ctx, indices);
+
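+    // Gather the selected dimension out of the 1D shape tensor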
+    auto gather_layer = ctx->net->addGather(*shape_1d_tensor, *indices_out, 0);
+    shape_1d_tensor = gather_layer->getOutput(0);
+  }
+
+  LOG_DEBUG("Output tensor shape: " << shape_1d_tensor->getDimensions());
+
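+  // Wrap the shape ITensor in a TensorContainer so it can travel through the evaluator stack as a c10::IValue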
+  auto tensor_holder = TensorContainer();
+  tensor_holder.hold_tensor(shape_1d_tensor);
+  auto shape_1d_ivalue = c10::IValue(std::move(c10::make_intrusive<TensorContainer>(tensor_holder)));
+
+  return shape_1d_ivalue;
+}
+
 DEFINE_GENERIC_TWO_INPUT_EVALUATOR(
     eq,
     "aten::eq",
@@ -176,7 +211,7 @@ auto aten_registrations TORCHTRT_UNUSED =
             {c10::Symbol::fromQualString("aten::full_like"),
              // aten::full_like(Tensor self, Scalar fill_value, *, ScalarType? dtype=None, Layout? layout=None,
              // Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> (Tensor)
-             [](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
+             [](ConversionCtx* ctx, const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
               // Override options related to layout and device for TensorRT
               auto options = torch::TensorOptions().layout(torch::kStrided).device(torch::kCUDA);
               auto input_tensor_var = args.at(n->input(0));
@@ -262,67 +297,81 @@ auto aten_registrations TORCHTRT_UNUSED =
                return static_cast<int64_t>(list.size());
              },
              EvalOptions().validSchemas({"aten::len.t(t[] a) -> (int)"})})
-        // .evaluator(
-        //     {c10::Symbol::fromQualString("aten::size"),
-        //      [](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
-        //        LOG_WARNING("There may be undefined behavior using dynamic shape and aten::size");
-        //        auto tensor_var = args.at(n->input(0));
-        //        if (n->inputs().size() == 1) {
-        //          if (tensor_var.isITensor()) {
-        //            auto tensor = tensor_var.ITensor();
-        //            return util::toVec(tensor->getDimensions());
-        //          } else if (tensor_var.IValue()->isTensor()) {
-        //            auto tensor = tensor_var.unwrapToTensor();
-        //            return tensor.sizes();
-        //          } else if (tensor_var.IValue()->isCustomClass()) {
-        //            auto tensor = tensor_var.IValue()->toCustomClass<TensorContainer>()->tensor();
-        //            return util::toVec(tensor->getDimensions());
-        //          } else {
-        //            TORCHTRT_THROW_ERROR("IValue is not some class of Tensor. Found: " << tensor_var.IValue()->type());
-        //          }
-        //        } else {
-        //          auto dim = args.at(n->input(1)).unwrapToInt();
-        //          if (tensor_var.isITensor()) {
-        //            auto tensor = tensor_var.ITensor();
-        //            auto dims = util::toVec(tensor->getDimensions());
-        //            auto nbDims = tensor->getDimensions().nbDims;
-        //            if (dim < 0) {
-        //              dim += nbDims;
-        //            }
-        //            return dims[dim];
-        //          } else if (tensor_var.IValue()->isTensor()) {
-        //            auto tensor = tensor_var.unwrapToTensor();
-        //            auto nbDims = tensor.sizes().size();
-        //            if (dim < 0) {
-        //              dim += nbDims;
-        //            }
-        //            return tensor.sizes()[dim];
-        //          } else if (tensor_var.IValue()->isCustomClass()) {
-        //            auto tensor = tensor_var.IValue()->toCustomClass<TensorContainer>()->tensor();
-        //            auto dims = util::toVec(tensor->getDimensions());
-        //            auto nbDims = tensor->getDimensions().nbDims;
-        //            if (dim < 0) {
-        //              dim += nbDims;
-        //            }
-        //            return dims[dim];
-        //          } else {
-        //            TORCHTRT_THROW_ERROR("IValue is not some class of Tensor. Found: " << tensor_var.IValue()->type());
-        //          }
-        //        }
-        //      },
-        //      EvalOptions().validSchemas(
-        //          {"aten::size(Tensor self) -> (int[])", "aten::size.int(Tensor self, int dim) -> (int)"})})
+        .evaluator(
+            {c10::Symbol::fromQualString("aten::size"),
+             [](ConversionCtx* ctx, const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
+               auto tensor_var = args.at(n->input(0));
+               if (n->inputs().size() == 1) {
+                 if (tensor_var.isITensor()) {
+                   auto tensor = tensor_var.ITensor();
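+                   // Static shapes can be read off at conversion time; dynamic shapes must be computed in-network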
+                   if (ctx->input_is_dynamic) {
+                     return dynamic_size_layer(ctx, n, args);
+                   }
+                   return util::toVec(tensor->getDimensions());
+                 } else if (tensor_var.IValue()->isTensor()) {
+                   auto tensor = tensor_var.unwrapToTensor();
+                   return tensor.sizes();
+                 } else if (tensor_var.IValue()->isCustomClass()) {
+                   auto tensor = tensor_var.IValue()->toCustomClass<TensorContainer>()->tensor();
+                   return util::toVec(tensor->getDimensions());
+                 } else {
+                   TORCHTRT_THROW_ERROR("IValue is not some class of Tensor. Found: " << tensor_var.IValue()->type());
+                 }
+               } else {
+                 auto dim = args.at(n->input(1)).unwrapToInt();
+                 if (tensor_var.isITensor()) {
+                   if (ctx->input_is_dynamic) {
+                     return dynamic_size_layer(ctx, n, args);
+                   }
+                   auto tensor = tensor_var.ITensor();
+                   auto dims = util::toVec(tensor->getDimensions());
+                   auto nbDims = tensor->getDimensions().nbDims;
+                   if (dim < 0) {
+                     dim += nbDims;
+                   }
+                   return dims[dim];
+                 } else if (tensor_var.IValue()->isTensor()) {
+                   auto tensor = tensor_var.unwrapToTensor();
+                   auto nbDims = tensor.sizes().size();
+                   if (dim < 0) {
+                     dim += nbDims;
+                   }
+                   return tensor.sizes()[dim];
+                 } else if (tensor_var.IValue()->isCustomClass()) {
+                   auto tensor = tensor_var.IValue()->toCustomClass<TensorContainer>()->tensor();
+                   auto dims = util::toVec(tensor->getDimensions());
+                   auto nbDims = tensor->getDimensions().nbDims;
+                   if (dim < 0) {
+                     dim += nbDims;
+                   }
+                   return dims[dim];
+                 } else {
+                   TORCHTRT_THROW_ERROR("IValue is not some class of Tensor. Found: " << tensor_var.IValue()->type());
+                 }
+               }
+             },
+             EvalOptions().validSchemas(
+                 {"aten::size(Tensor self) -> (int[])", "aten::size.int(Tensor self, int dim) -> (int)"})})
         .evaluator(
             {c10::Symbol::fromQualString("aten::__getitem__"),
              [](ConversionCtx* ctx, const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
-               auto list = args.at(n->input(0)).IValue()->to<c10::List<c10::IValue>>();
+               auto list_input = args.at(n->input(0));
                auto idx = args.at(n->input(1)).unwrapToInt();
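+               // A list still held as an IValue can be indexed at conversion time; an ITensor (e.g. a shape from the dynamic aten::size path) is only resolvable at runtime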
+               if (list_input.isIValue()) {
+                 auto list = args.at(n->input(0)).IValue()->to<c10::List<c10::IValue>>();
+                 const int64_t list_size = list.size();
+                 const int64_t normalized_idx = normalizeIndex(idx, list_size);
+                 TORCHTRT_CHECK(
+                     normalized_idx >= 0 && normalized_idx < list_size, "List index out of range (aten::__getitem__)");
+                 return list.get(normalized_idx);
+               } else if (list_input.isITensor()) {
+                 return dynamic_size_layer(ctx, n, args);
+               } else {
+                 TORCHTRT_THROW_ERROR("Unsupported input type for aten::__getitem__");
+               }
 
-               const int64_t list_size = list.size();
-               const int64_t normalized_idx = normalizeIndex(idx, list_size);
-               TORCHTRT_CHECK(
-                   normalized_idx >= 0 || normalized_idx < list_size, "List index out of range (aten::__getitem__)");
-               return list.get(normalized_idx);
              },
              EvalOptions().validSchemas({
                  "aten::__getitem__.t(t[](a) list, int idx) -> (t(*))",