|
| 1 | +#include "core/conversion/converters/converters.h" |
| 2 | +#include "core/util/prelude.h" |
| 3 | + |
| 4 | +#include <torch/torch.h> |
| 5 | + |
| 6 | +namespace torch_tensorrt { |
| 7 | +namespace core { |
| 8 | +namespace conversion { |
| 9 | +namespace converters { |
| 10 | +namespace impl { |
| 11 | + |
| 12 | +auto bitwise_not_registrations TORCHTRT_UNUSED = RegisterNodeConversionPatterns().pattern( |
| 13 | + {"aten::bitwise_not(Tensor self) -> Tensor", [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool { |
| 14 | + auto in = args[0].ITensorOrFreeze(ctx); |
| 15 | + nvinfer1::ILayer* out; |
| 16 | + |
| 17 | + if (in->getType() == nvinfer1::DataType::kINT32) { |
| 18 | + // Integer case, using ~x = -x - 1 |
| 19 | + auto neg_one = torch::tensor({-1}, util::TRTDataTypeToScalarType(in->getType())); |
| 20 | + auto neg_one_const = tensor_to_const(ctx, neg_one); |
| 21 | + auto neg = add_elementwise( |
| 22 | + ctx, |
| 23 | + nvinfer1::ElementWiseOperation::kPROD, |
| 24 | + in, |
| 25 | + neg_one_const, |
| 26 | + util::node_info(n) + std::string("_Negation")); |
| 27 | + TORCHTRT_CHECK(neg, "Unable to create prod layer from node: " << *n); |
| 28 | + out = add_elementwise( |
| 29 | + ctx, |
| 30 | + nvinfer1::ElementWiseOperation::kSUM, |
| 31 | + neg->getOutput(0), |
| 32 | + neg_one_const, |
| 33 | + util::node_info(n) + std::string("_SubOne")); |
| 34 | + TORCHTRT_CHECK(out, "Unable to create sum layer from node: " << *n); |
| 35 | + } else if (in->getType() == nvinfer1::DataType::kBOOL) { |
| 36 | + // Boolean case |
| 37 | + out = ctx->net->addUnary(*in, nvinfer1::UnaryOperation::kNOT); |
| 38 | + TORCHTRT_CHECK(out, "Unable to create logical not layer from node: " << *n); |
| 39 | + } else { |
| 40 | + LOG_ERROR("Input tensor must be 32 bit integer or boolean"); |
| 41 | + return false; |
| 42 | + } |
| 43 | + |
| 44 | + out->setName(util::node_info(n).c_str()); |
| 45 | + auto out_tensor = ctx->AssociateValueAndTensor(n->outputs()[0], out->getOutput(0)); |
| 46 | + LOG_DEBUG("Output tensor shape: " << out_tensor->getDimensions()); |
| 47 | + |
| 48 | + return true; |
| 49 | + }}); |
| 50 | + |
| 51 | +} // namespace impl |
| 52 | +} // namespace converters |
| 53 | +} // namespace conversion |
| 54 | +} // namespace core |
| 55 | +} // namespace torch_tensorrt |
0 commit comments