@@ -9,48 +9,44 @@ namespace conversion {
99namespace converters {
1010namespace impl {
1111
12-
13- auto bitwisenot TORCHTRT_UNUSED =
14- RegisterNodeConversionPatterns ()
15- .pattern({" aten::bitwise_not(Tensor self) -> Tensor" ,
16- [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
17- auto in = args[0 ].ITensorOrFreeze (ctx);
18- nvinfer1::ILayer* out;
19-
20- if (in->getType () == nvinfer1::DataType::kINT32 ) {
21- // Integer case, using ~x = -x - 1
22- auto neg_one = torch::tensor ({-1 }, util::TRTDataTypeToScalarType (in->getType ()));
23- auto neg_one_const = tensor_to_const (ctx, neg_one);
24- auto neg = add_elementwise (
25- ctx,
26- nvinfer1::ElementWiseOperation::kPROD ,
27- in,
28- neg_one_const,
29- util::node_info (n) + std::string (" _Negation" ));
30- TORCHTRT_CHECK (neg, " Unable to create prod layer from node: " << *n);
31- out = add_elementwise (
32- ctx,
33- nvinfer1::ElementWiseOperation::kSUM ,
34- neg->getOutput (0 ),
35- neg_one_const,
36- util::node_info (n) + std::string (" _SubOne" ));
37- TORCHTRT_CHECK (out, " Unable to create sum layer from node: " << *n);
38- } else if (in->getType () == nvinfer1::DataType::kBOOL ) {
39- // Boolean case
40- out = ctx->net ->addUnary (*in, nvinfer1::UnaryOperation::kNOT );
41- TORCHTRT_CHECK (out, " Unable to create logical not layer from node: " << *n);
42- } else {
43- LOG_ERROR (" Input tensor must be 32 bit integer or boolean" );
44- return false ;
45- }
46-
47- out->setName (util::node_info (n).c_str ());
48- auto out_tensor = ctx->AssociateValueAndTensor (n->outputs ()[0 ], out->getOutput (0 ));
49- LOG_DEBUG (" Output tensor shape: " << out_tensor->getDimensions ());
50-
51- return true ;
52- }});
53-
12+ auto bitwise_not_registrations TORCHTRT_UNUSED = RegisterNodeConversionPatterns().pattern(
13+ {" aten::bitwise_not(Tensor self) -> Tensor" , [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
14+ auto in = args[0 ].ITensorOrFreeze (ctx);
15+ nvinfer1::ILayer* out;
16+
17+ if (in->getType () == nvinfer1::DataType::kINT32 ) {
18+ // Integer case, using ~x = -x - 1
19+ auto neg_one = torch::tensor ({-1 }, util::TRTDataTypeToScalarType (in->getType ()));
20+ auto neg_one_const = tensor_to_const (ctx, neg_one);
21+ auto neg = add_elementwise (
22+ ctx,
23+ nvinfer1::ElementWiseOperation::kPROD ,
24+ in,
25+ neg_one_const,
26+ util::node_info (n) + std::string (" _Negation" ));
27+ TORCHTRT_CHECK (neg, " Unable to create prod layer from node: " << *n);
28+ out = add_elementwise (
29+ ctx,
30+ nvinfer1::ElementWiseOperation::kSUM ,
31+ neg->getOutput (0 ),
32+ neg_one_const,
33+ util::node_info (n) + std::string (" _SubOne" ));
34+ TORCHTRT_CHECK (out, " Unable to create sum layer from node: " << *n);
35+ } else if (in->getType () == nvinfer1::DataType::kBOOL ) {
36+ // Boolean case
37+ out = ctx->net ->addUnary (*in, nvinfer1::UnaryOperation::kNOT );
38+ TORCHTRT_CHECK (out, " Unable to create logical not layer from node: " << *n);
39+ } else {
40+ LOG_ERROR (" Input tensor must be 32 bit integer or boolean" );
41+ return false ;
42+ }
43+
44+ out->setName (util::node_info (n).c_str ());
45+ auto out_tensor = ctx->AssociateValueAndTensor (n->outputs ()[0 ], out->getOutput (0 ));
46+ LOG_DEBUG (" Output tensor shape: " << out_tensor->getDimensions ());
47+
48+ return true ;
49+ }});
5450
5551} // namespace impl
5652} // namespace converters
0 commit comments