Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions lib/Conversion/TorchToTosa/TorchToTosa.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
#include <numeric>
#include <optional>
#include <random>
#include <type_traits>

#include "mlir/Dialect/Tosa/Utils/QuantUtils.h"

Expand Down Expand Up @@ -91,6 +92,16 @@ class ConvertAtenUnaryOp : public OpConversionPattern<AtenOpT> {

self = tosa::tosaCastTensorToType(rewriter, self, outType).value();

if constexpr (std::is_same_v<AtenOpT, AtenBitwiseNotOp>) {
if (auto intTy = dyn_cast<IntegerType>(outType.getElementType())) {
if (intTy.getWidth() == 1) {
rewriter.replaceOpWithNewOp<tosa::LogicalNotOp>(op, outType, self);
return success();
}
}
// otherwise fall through to standard emission
}

rewriter.replaceOpWithNewOp<TosaOpT>(op, outType, self);

return success();
Expand Down
23 changes: 23 additions & 0 deletions projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py
Original file line number Diff line number Diff line change
Expand Up @@ -5175,6 +5175,29 @@ def ElementwiseBitwiseNotInt32Module_basic(module, tu: TestUtils):
# ==============================================================================


class ElementwiseBitwiseNotBoolModule(torch.nn.Module):
def __init__(self):
super().__init__()

@export
@annotate_args(
[
None,
([-1, -1], torch.bool, True),
]
)
def forward(self, x):
return torch.bitwise_not(x)


@register_test_case(module_factory=lambda: ElementwiseBitwiseNotBoolModule())
def ElementwiseBitwiseNotBoolModule_basic(module, tu: TestUtils):
module.forward(tu.randint(3, 4, low=0, high=2).to(torch.bool))


# ==============================================================================


class ElementwiseSubTensorInt8Module(torch.nn.Module):
def __init__(self):
super().__init__()
Expand Down
14 changes: 14 additions & 0 deletions test/Conversion/TorchToTosa/basic.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,20 @@ func.func @torch.aten.bitwise_not$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !to

// -----

// CHECK-LABEL: func.func @torch.aten.bitwise_not$bool(
// CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[2,3],i1>) -> !torch.vtensor<[2,3],i1> {
// CHECK: %[[ARG_BUILTIN:.*]] = torch_c.to_builtin_tensor %[[ARG]] : !torch.vtensor<[2,3],i1> -> tensor<2x3xi1>
// CHECK: %[[RESULT_BUILTIN:.*]] = tosa.logical_not %[[ARG_BUILTIN]] : (tensor<2x3xi1>) -> tensor<2x3xi1>
// CHECK: %[[RESULT:.*]] = torch_c.from_builtin_tensor %[[RESULT_BUILTIN]] : tensor<2x3xi1> -> !torch.vtensor<[2,3],i1>
// CHECK: return %[[RESULT]] : !torch.vtensor<[2,3],i1>
// CHECK: }
func.func @torch.aten.bitwise_not$bool(%arg0: !torch.vtensor<[2,3],i1>) -> !torch.vtensor<[2,3],i1> {
%0 = torch.aten.bitwise_not %arg0 : !torch.vtensor<[2,3],i1> -> !torch.vtensor<[2,3],i1>
return %0 : !torch.vtensor<[2,3],i1>
}

// -----

// CHECK-LABEL: func.func @torch.aten.ceil$basic(
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
Expand Down
Loading