From cf23a4ca56710088b3a6573288db529f87a99f26 Mon Sep 17 00:00:00 2001 From: Vitalii Shutov Date: Mon, 3 Nov 2025 15:41:07 +0000 Subject: [PATCH] [TOSA] Lower boolean aten.bitwise_not to logical_not - Fix TorchToTosa's shared unary pattern so AtenBitwiseNotOp with i1 outputs emits tosa.logical_not instead of tosa.bitwise_not. - Add a regression test in test/Conversion/TorchToTosa/basic.mlir that checks the lowering path for a bool tensor. - Add an end-to-end regression test for bitwise_not with boolean inputs. Change-Id: I1742adcd150d3a1ddb33b9dd58001d71f3e08860 --- lib/Conversion/TorchToTosa/TorchToTosa.cpp | 11 +++++++++ .../test_suite/elementwise.py | 23 +++++++++++++++++++ test/Conversion/TorchToTosa/basic.mlir | 14 +++++++++++ 3 files changed, 48 insertions(+) diff --git a/lib/Conversion/TorchToTosa/TorchToTosa.cpp b/lib/Conversion/TorchToTosa/TorchToTosa.cpp index 0bc93f711ad6..f0465aa801f8 100644 --- a/lib/Conversion/TorchToTosa/TorchToTosa.cpp +++ b/lib/Conversion/TorchToTosa/TorchToTosa.cpp @@ -28,6 +28,7 @@ #include #include #include +#include <type_traits> #include "mlir/Dialect/Tosa/Utils/QuantUtils.h" @@ -91,6 +92,16 @@ class ConvertAtenUnaryOp : public OpConversionPattern<AtenOpT> { self = tosa::tosaCastTensorToType(rewriter, self, outType).value(); + if constexpr (std::is_same_v<AtenOpT, AtenBitwiseNotOp>) { + if (auto intTy = dyn_cast<IntegerType>(outType.getElementType())) { + if (intTy.getWidth() == 1) { + rewriter.replaceOpWithNewOp<tosa::LogicalNotOp>(op, outType, self); + return success(); + } + } + // otherwise fall through to standard emission + } + rewriter.replaceOpWithNewOp<TosaOpT>(op, outType, self); return success(); diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py index c00e48f39e88..40b6b1b873af 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py +++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py @@ -5175,6 +5175,29 @@ def ElementwiseBitwiseNotInt32Module_basic(module, tu: TestUtils): #
============================================================================== +class ElementwiseBitwiseNotBoolModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args( + [ + None, + ([-1, -1], torch.bool, True), + ] + ) + def forward(self, x): + return torch.bitwise_not(x) + + +@register_test_case(module_factory=lambda: ElementwiseBitwiseNotBoolModule()) +def ElementwiseBitwiseNotBoolModule_basic(module, tu: TestUtils): + module.forward(tu.randint(3, 4, low=0, high=2).to(torch.bool)) + + +# ============================================================================== + + class ElementwiseSubTensorInt8Module(torch.nn.Module): def __init__(self): super().__init__() diff --git a/test/Conversion/TorchToTosa/basic.mlir b/test/Conversion/TorchToTosa/basic.mlir index d100fe9dcfde..f0f1f4b9ed6c 100644 --- a/test/Conversion/TorchToTosa/basic.mlir +++ b/test/Conversion/TorchToTosa/basic.mlir @@ -135,6 +135,20 @@ func.func @torch.aten.bitwise_not$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !to // ----- +// CHECK-LABEL: func.func @torch.aten.bitwise_not$bool( +// CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[2,3],i1>) -> !torch.vtensor<[2,3],i1> { +// CHECK: %[[ARG_BUILTIN:.*]] = torch_c.to_builtin_tensor %[[ARG]] : !torch.vtensor<[2,3],i1> -> tensor<2x3xi1> +// CHECK: %[[RESULT_BUILTIN:.*]] = tosa.logical_not %[[ARG_BUILTIN]] : (tensor<2x3xi1>) -> tensor<2x3xi1> +// CHECK: %[[RESULT:.*]] = torch_c.from_builtin_tensor %[[RESULT_BUILTIN]] : tensor<2x3xi1> -> !torch.vtensor<[2,3],i1> +// CHECK: return %[[RESULT]] : !torch.vtensor<[2,3],i1> +// CHECK: } +func.func @torch.aten.bitwise_not$bool(%arg0: !torch.vtensor<[2,3],i1>) -> !torch.vtensor<[2,3],i1> { + %0 = torch.aten.bitwise_not %arg0 : !torch.vtensor<[2,3],i1> -> !torch.vtensor<[2,3],i1> + return %0 : !torch.vtensor<[2,3],i1> +} + +// ----- + // CHECK-LABEL: func.func @torch.aten.ceil$basic( // CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> { // CHECK: 
%[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],f32> -> tensor