@@ -18,16 +18,22 @@ auto bitwisenot TORCHTRT_UNUSED =
1818 nvinfer1::ILayer* out;
1919
2020 if (in->getType () == nvinfer1::DataType::kINT32 ) {
21- // Integer case
21+ // Integer case, using ~x = -x - 1
2222 auto neg_one = torch::tensor ({-1 }, util::TRTDataTypeToScalarType (in->getType ()));
2323 auto neg_one_const = tensor_to_const (ctx, neg_one);
2424 auto neg = add_elementwise (
25- ctx, nvinfer1::ElementWiseOperation::kPROD , in,
26- neg_one_const, util::node_info (n) + std::string (" _Negation" ));
25+ ctx,
26+ nvinfer1::ElementWiseOperation::kPROD ,
27+ in,
28+ neg_one_const,
29+ util::node_info (n) + std::string (" _Negation" ));
2730 TORCHTRT_CHECK (neg, " Unable to create prod layer from node: " << *n);
2831 out = add_elementwise (
29- ctx, nvinfer1::ElementWiseOperation::kSUM , neg->getOutput (0 ),
30- neg_one_const, util::node_info (n) + std::string (" _SubOne" ));
32+ ctx,
33+ nvinfer1::ElementWiseOperation::kSUM ,
34+ neg->getOutput (0 ),
35+ neg_one_const,
36+ util::node_info (n) + std::string (" _SubOne" ));
3137 TORCHTRT_CHECK (out, " Unable to create sum layer from node: " << *n);
3238 } else if (in->getType () == nvinfer1::DataType::kBOOL ) {
3339 // Boolean case
@@ -39,8 +45,7 @@ auto bitwisenot TORCHTRT_UNUSED =
3945 }
4046
4147 out->setName (util::node_info (n).c_str ());
42- auto out_tensor = ctx->AssociateValueAndTensor (n->outputs ()[0 ],
43- out->getOutput (0 ));
48+ auto out_tensor = ctx->AssociateValueAndTensor (n->outputs ()[0 ], out->getOutput (0 ));
4449 LOG_DEBUG (" Output tensor shape: " << out_tensor->getDimensions ());
4550
4651 return true ;
0 commit comments