@@ -129,24 +129,24 @@ nvinfer1::ITensor* applyIdentityOp(ConversionCtx* ctx, nvinfer1::ITensor* tensor
129129}
130130
131131nvinfer1::ITensor* castITensor (ConversionCtx* ctx, nvinfer1::ITensor* tensor, nvinfer1::DataType dtype) {
132- // No matter whether tensor->getType() == dtype, identity layer is always needed.
133- // Otherwise will change the input tensor name in aten::to converter by AssociateValueAndTensor function
134- // When the input of aten::to is network input, will cause error
135- std::ostringstream tensor_id;
136- tensor_id << reinterpret_cast <int *>(tensor);
137-
138- auto id_layer = ctx->net ->addIdentity (*tensor);
139- TORCHTRT_CHECK (id_layer, " Unable to create identity layer for ITensor: " << tensor_id.str ());
140- auto casted_tensor = id_layer->getOutput (0 );
141- casted_tensor->setType (dtype);
132+ if (tensor->getType () != dtype) {
133+ std::ostringstream tensor_id;
134+ tensor_id << reinterpret_cast <int *>(tensor);
142135
143- LOG_DEBUG (ctx->logger , " Casting ITensor " << tensor_id.str () << " from " << tensor->getType () << " to " << dtype);
136+ auto id_layer = ctx->net ->addIdentity (*tensor);
137+ TORCHTRT_CHECK (id_layer, " Unable to create identity layer for ITensor: " << tensor_id.str ());
138+ auto casted_tensor = id_layer->getOutput (0 );
139+ casted_tensor->setType (dtype);
144140
145- std::stringstream ss;
146- ss << " [Cast ITensor " << tensor_id.str () << " from " << tensor->getType () << " to " << dtype << " ]" ;
147- id_layer->setName (ss.str ().c_str ());
148- return casted_tensor;
141+ LOG_DEBUG (ctx->logger , " Casting ITensor " << tensor_id.str () << " from " << tensor->getType () << " to " << dtype);
149142
143+ std::stringstream ss;
144+ ss << " [Cast ITensor " << tensor_id.str () << " from " << tensor->getType () << " to " << dtype << " ]" ;
145+ id_layer->setName (ss.str ().c_str ());
146+ return casted_tensor;
147+ } else {
148+ return tensor;
149+ }
150150}
151151
152152nvinfer1::ITensor* tensor_to_const (ConversionCtx* ctx, at::Tensor t, const std::string& name) {
0 commit comments