@@ -129,24 +129,24 @@ nvinfer1::ITensor* applyIdentityOp(ConversionCtx* ctx, nvinfer1::ITensor* tensor
129129}
130130
131131nvinfer1::ITensor* castITensor (ConversionCtx* ctx, nvinfer1::ITensor* tensor, nvinfer1::DataType dtype) {
132- if (tensor->getType () != dtype) {
133- std::ostringstream tensor_id;
134- tensor_id << reinterpret_cast <int *>(tensor);
132+ // No matter whether tensor->getType() == dtype, identity layer is always needed.
133+ // Otherwise will change the input tensor name in aten::to converter by AssociateValueAndTensor function
134+ // When the input of aten::to is network input, will cause error
135+ std::ostringstream tensor_id;
136+ tensor_id << reinterpret_cast <int *>(tensor);
135137
136- auto id_layer = ctx->net ->addIdentity (*tensor);
137- TORCHTRT_CHECK (id_layer, " Unable to create identity layer for ITensor: " << tensor_id.str ());
138- auto casted_tensor = id_layer->getOutput (0 );
139- casted_tensor->setType (dtype);
138+ auto id_layer = ctx->net ->addIdentity (*tensor);
139+ TORCHTRT_CHECK (id_layer, " Unable to create identity layer for ITensor: " << tensor_id.str ());
140+ auto casted_tensor = id_layer->getOutput (0 );
141+ casted_tensor->setType (dtype);
140142
141- LOG_DEBUG (ctx->logger , " Casting ITensor " << tensor_id.str () << " from " << tensor->getType () << " to " << dtype);
143+ LOG_DEBUG (ctx->logger , " Casting ITensor " << tensor_id.str () << " from " << tensor->getType () << " to " << dtype);
144+
145+ std::stringstream ss;
146+ ss << " [Cast ITensor " << tensor_id.str () << " from " << tensor->getType () << " to " << dtype << " ]" ;
147+ id_layer->setName (ss.str ().c_str ());
148+ return casted_tensor;
142149
143- std::stringstream ss;
144- ss << " [Cast ITensor " << tensor_id.str () << " from " << tensor->getType () << " to " << dtype << " ]" ;
145- id_layer->setName (ss.str ().c_str ());
146- return casted_tensor;
147- } else {
148- return tensor;
149- }
150150}
151151
152152nvinfer1::ITensor* tensor_to_const (ConversionCtx* ctx, at::Tensor t, const std::string& name) {
0 commit comments