@@ -26,7 +26,7 @@ auto cast_registrations TORCHTRT_UNUSED =
2626 } else {
2727 trt_dtype = util::ScalarTypeToTRTDataType (static_cast <at::ScalarType>(output_dtype));
2828 }
29- auto casted_itensor = castITensor (ctx, self, trt_dtype);
29+ auto casted_itensor = castITensor (ctx, self, trt_dtype, util::node_info (n) );
3030 auto output = ctx->AssociateValueAndTensor (n->outputs ()[0 ], casted_itensor);
3131 LOG_DEBUG (" [aten::to.dtype] Output tensor shape: " << output->getDimensions ());
3232
@@ -48,7 +48,7 @@ auto cast_registrations TORCHTRT_UNUSED =
4848 } else {
4949 trt_dtype = util::ScalarTypeToTRTDataType (static_cast <at::ScalarType>(output_dtype));
5050 }
51- auto casted_itensor = castITensor (ctx, self, trt_dtype);
51+ auto casted_itensor = castITensor (ctx, self, trt_dtype, util::node_info (n) );
5252 auto output = ctx->AssociateValueAndTensor (n->outputs ()[0 ], casted_itensor);
5353 LOG_DEBUG (" [aten::to.device] Output tensor shape: " << output->getDimensions ());
5454
@@ -59,7 +59,7 @@ auto cast_registrations TORCHTRT_UNUSED =
5959 [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
6060 auto self = args[0 ].ITensorOrFreeze (ctx);
6161 nvinfer1::DataType other_dtype = args[1 ].ITensorOrFreeze (ctx)->getType ();
62- auto casted_itensor = castITensor (ctx, self, other_dtype);
62+ auto casted_itensor = castITensor (ctx, self, other_dtype, util::node_info (n) );
6363 auto output = ctx->AssociateValueAndTensor (n->outputs ()[0 ], casted_itensor);
6464 LOG_DEBUG (" [aten::to.other] Output tensor shape: " << output->getDimensions ());
6565
@@ -77,7 +77,7 @@ auto cast_registrations TORCHTRT_UNUSED =
7777
7878 auto output_dtype = args[2 ].unwrapToScalar ().to <int64_t >();
7979 auto trt_dtype = util::ScalarTypeToTRTDataType (static_cast <at::ScalarType>(output_dtype));
80- auto casted_itensor = castITensor (ctx, self, trt_dtype);
80+ auto casted_itensor = castITensor (ctx, self, trt_dtype, util::node_info (n) );
8181 auto output = ctx->AssociateValueAndTensor (n->outputs ()[0 ], casted_itensor);
8282 LOG_DEBUG (" [aten::to.prim_Device] Output tensor shape: " << output->getDimensions ());
8383
0 commit comments