From 175d2a21926bd516a71fdee666f469e1b05c2365 Mon Sep 17 00:00:00 2001
From: martino
Date: Wed, 29 Oct 2025 15:26:40 +0100
Subject: [PATCH] [Torch] Fold aten rounding ops on splat constants.

This commit teaches the folders of `AtenFloor`, `AtenCeil`, `AtenRound`, and
`AtenTrunc` to constant-fold the rounding when the operand is a splat
`DenseElementsAttr`.
---
 lib/Dialect/Torch/IR/TorchOps.cpp     | 81 ++++++++++++++++---
 .../test_suite/elementwise.py         | 64 +++++++++++++++
 test/Dialect/Torch/canonicalize.mlir  | 58 +++++++++++++
 3 files changed, 191 insertions(+), 12 deletions(-)

diff --git a/lib/Dialect/Torch/IR/TorchOps.cpp b/lib/Dialect/Torch/IR/TorchOps.cpp
index a4888a218fae..df36697ca1fa 100644
--- a/lib/Dialect/Torch/IR/TorchOps.cpp
+++ b/lib/Dialect/Torch/IR/TorchOps.cpp
@@ -234,6 +234,28 @@ static Value getScalarFloatValue(Value input, Location loc,
   return nullptr;
 }
 
+// Common helper for splat-only rounding-based folders.
+static OpFoldResult foldSplatRounding(ValueTensorType resultType,
+                                      Attribute selfAttr,
+                                      APFloat::roundingMode mode) {
+  auto elems = dyn_cast_or_null<DenseElementsAttr>(selfAttr);
+  if (!elems || !elems.isSplat())
+    return {};
+
+  if (!isa<mlir::FloatType>(resultType.getDtype()))
+    return {};
+
+  auto outShaped = resultType.toBuiltinTensor();
+  if (!outShaped.hasStaticShape())
+    return {};
+
+  APFloat v = elems.getSplatValue<APFloat>();
+  // NaN and inf values round the same way as in PyTorch, so the opStatus
+  // returned by roundToIntegral can be discarded.
+  (void)v.roundToIntegral(mode);
+  return DenseElementsAttr::get(outShaped, v);
+}
+
 //===----------------------------------------------------------------------===//
 // MethodOp
 //===----------------------------------------------------------------------===//
@@ -2064,10 +2086,19 @@ OpFoldResult AtenLogOp::fold(FoldAdaptor adaptor) {
 
 OpFoldResult AtenFloorOp::fold(FoldAdaptor adaptor) {
   auto resultType = dyn_cast<ValueTensorType>(getType());
-  if (resultType && resultType.hasDtype() &&
-      isa<mlir::IntegerType>(resultType.getDtype())) {
+
+  if (!resultType || !resultType.hasDtype())
+    return {};
+
+  // Rounding an integer tensor is a no-op; fold to the input.
+  if (isa<mlir::IntegerType>(resultType.getDtype()))
     return getSelf();
-  }
+
+  // Fold float splats.
+  if (auto res = foldSplatRounding(resultType, /*selfAttr*/ adaptor.getSelf(),
+                                   /*mode*/ APFloat::rmTowardNegative))
+    return res;
+
   return {};
 }
 
@@ -2077,10 +2108,19 @@ OpFoldResult AtenFloorOp::fold(FoldAdaptor adaptor) {
 
 OpFoldResult AtenCeilOp::fold(FoldAdaptor adaptor) {
   auto resultType = dyn_cast<ValueTensorType>(getType());
-  if (resultType && resultType.hasDtype() &&
-      isa<mlir::IntegerType>(resultType.getDtype())) {
+
+  if (!resultType || !resultType.hasDtype())
+    return {};
+
+  // Rounding an integer tensor is a no-op; fold to the input.
+  if (isa<mlir::IntegerType>(resultType.getDtype()))
     return getSelf();
-  }
+
+  // Fold float splats.
+  if (auto res = foldSplatRounding(resultType, /*selfAttr*/ adaptor.getSelf(),
+                                   /*mode*/ APFloat::rmTowardPositive))
+    return res;
+
   return {};
 }
 
@@ -2103,10 +2143,18 @@ OpFoldResult AtenRoundDecimalsOp::fold(FoldAdaptor adaptor) {
 
 OpFoldResult AtenRoundOp::fold(FoldAdaptor adaptor) {
   auto resultType = dyn_cast<ValueTensorType>(getType());
-  if (resultType && resultType.hasDtype() &&
-      isa<mlir::IntegerType>(resultType.getDtype())) {
+  if (!resultType || !resultType.hasDtype())
+    return {};
+
+  // Rounding an integer tensor is a no-op; fold to the input.
+  if (isa<mlir::IntegerType>(resultType.getDtype()))
     return getSelf();
-  }
+
+  // Fold float splats.
+  if (auto res = foldSplatRounding(resultType, /*selfAttr*/ adaptor.getSelf(),
+                                   /*mode*/ APFloat::rmNearestTiesToEven))
+    return res;
+
   return {};
 }
 
@@ -2116,10 +2164,19 @@ OpFoldResult AtenRoundOp::fold(FoldAdaptor adaptor) {
 
 OpFoldResult AtenTruncOp::fold(FoldAdaptor adaptor) {
   auto resultType = dyn_cast<ValueTensorType>(getType());
-  if (resultType && resultType.hasDtype() &&
-      isa<mlir::IntegerType>(resultType.getDtype())) {
+
+  if (!resultType || !resultType.hasDtype())
+    return {};
+
+  // Rounding an integer tensor is a no-op; fold to the input.
+  if (isa<mlir::IntegerType>(resultType.getDtype()))
     return getSelf();
-  }
+
+  // Fold float splats.
+  if (auto res = foldSplatRounding(resultType, /*selfAttr*/ adaptor.getSelf(),
+                                   /*mode*/ APFloat::rmTowardZero))
+    return res;
+
   return {};
 }
 
diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py
index c00e48f39e88..084fd31368c2 100644
--- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py
+++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py
@@ -6844,6 +6844,70 @@ def AtenRoundFloatDecimalsModule_basic(module, tu: TestUtils):
     module.forward(tu.rand(5, 5, low=-3.0, high=3.0))
 
 
+class AtenRoundNegFloatHalfToEvenSplatModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.const = torch.tensor([-1.5, -1.5], dtype=torch.float32)
+
+    @export
+    @annotate_args([None])
+    def forward(self):
+        return torch.ops.aten.round(self.const)
+
+
+@register_test_case(module_factory=lambda: AtenRoundNegFloatHalfToEvenSplatModule())
+def AtenRoundNegFloatHalfToEvenSplatModule_basic(module, tu: TestUtils):
+    module.forward()
+
+
+class AtenRoundPosFloatHalfToEvenSplatModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.const = torch.tensor([1.5, 1.5], dtype=torch.float32)
+
+    @export
+    @annotate_args([None])
+    def forward(self):
+        return torch.ops.aten.round(self.const)
+
+
+@register_test_case(module_factory=lambda: AtenRoundPosFloatHalfToEvenSplatModule())
+def AtenRoundPosFloatHalfToEvenSplatModule_basic(module, tu: TestUtils):
+    module.forward()
+
+
+class AtenRoundInfSplatModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.const = torch.tensor([float("+inf")], dtype=torch.float32)
+
+    @export
+    @annotate_args([None])
+    def forward(self):
+        return torch.ops.aten.round(self.const)
+
+
+@register_test_case(module_factory=lambda: AtenRoundInfSplatModule())
+def AtenRoundInfSplatModule_basic(module, tu: TestUtils):
+    module.forward()
+
+
+class AtenRoundNanSplatModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.const = torch.tensor([float("nan")], dtype=torch.float32)
+
+    @export
+    @annotate_args([None])
+    def forward(self):
+        return torch.ops.aten.round(self.const)
+
+
+@register_test_case(module_factory=lambda: AtenRoundNanSplatModule())
+def AtenRoundNanSplatModule_basic(module, tu: TestUtils):
+    module.forward()
+
+
 # ==============================================================================
 
 
diff --git a/test/Dialect/Torch/canonicalize.mlir b/test/Dialect/Torch/canonicalize.mlir
index 48092b71a875..31a76244eaa6 100644
--- a/test/Dialect/Torch/canonicalize.mlir
+++ b/test/Dialect/Torch/canonicalize.mlir
@@ -3596,3 +3596,61 @@ func.func @torch.aten.full$int_fold() -> !torch.vtensor<[2,1,4],si64> {
   %1 = torch.aten.full %0, %int-Inf, %none, %none, %none, %none : !torch.list<int>, !torch.int, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[2,1,4],si64>
   return %1 : !torch.vtensor<[2,1,4],si64>
 }
+
+// -----
+
+// CHECK-LABEL: func.func @torch.aten.ceil$fold
+// CHECK: %[[C:.*]] = torch.vtensor.literal(dense<-1.000000e+00> : tensor<2x2xf32>)
+// CHECK: return %[[C]]
+func.func @torch.aten.ceil$fold() -> !torch.vtensor<[2,2],f32> {
+  %cst = torch.vtensor.literal(dense<-1.100000e+00> : tensor<2x2xf32>)
+      : !torch.vtensor<[2,2],f32>
+  %r = torch.aten.ceil %cst : !torch.vtensor<[2,2],f32> -> !torch.vtensor<[2,2],f32>
+  return %r : !torch.vtensor<[2,2],f32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @torch.aten.floor$fold
+// CHECK: %[[C:.*]] = torch.vtensor.literal(dense<1.000000e+00> : tensor<3x4xf32>)
+// CHECK: return %[[C]]
+func.func @torch.aten.floor$fold() -> !torch.vtensor<[3,4],f32> {
+  %cst = torch.vtensor.literal(dense<1.900000e+00> : tensor<3x4xf32>)
+      : !torch.vtensor<[3,4],f32>
+  %r = torch.aten.floor %cst : !torch.vtensor<[3,4],f32> -> !torch.vtensor<[3,4],f32>
+  return %r : !torch.vtensor<[3,4],f32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @torch.aten.trunc$fold
+// CHECK: %[[C:.*]] = torch.vtensor.literal(dense<-3.000000e+00> : tensor<1x3xf32>)
+// CHECK: return %[[C]]
+func.func @torch.aten.trunc$fold() -> !torch.vtensor<[1,3],f32> {
+  %cst = torch.vtensor.literal(dense<-3.700000e+00> : tensor<1x3xf32>)
+      : !torch.vtensor<[1,3],f32>
+  %r = torch.aten.trunc %cst : !torch.vtensor<[1,3],f32> -> !torch.vtensor<[1,3],f32>
+  return %r : !torch.vtensor<[1,3],f32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @torch.aten.round$fold
+// CHECK-DAG: %[[POS:.*]] = torch.vtensor.literal(dense<2.000000e+00> : tensor<4x5xf32>)
+// CHECK-DAG: %[[NEG:.*]] = torch.vtensor.literal(dense<-2.000000e+00> : tensor<2x3xf32>)
+// CHECK: return %[[POS]], %[[NEG]]
+func.func @torch.aten.round$fold()
+    -> (!torch.vtensor<[4,5],f32>, !torch.vtensor<[2,3],f32>) {
+  %cpos = torch.vtensor.literal(dense<2.500000e+00> : tensor<4x5xf32>)
+      : !torch.vtensor<[4,5],f32>
+  %rpos = torch.aten.round %cpos
+      : !torch.vtensor<[4,5],f32> -> !torch.vtensor<[4,5],f32>
+
+  %cneg = torch.vtensor.literal(dense<-2.500000e+00> : tensor<2x3xf32>)
+      : !torch.vtensor<[2,3],f32>
+  %rneg = torch.aten.round %cneg
+      : !torch.vtensor<[2,3],f32> -> !torch.vtensor<[2,3],f32>
+
+  return %rpos, %rneg
+      : !torch.vtensor<[4,5],f32>, !torch.vtensor<[2,3],f32>
+}
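
For reference, the rounding modes chosen in foldSplatRounding mirror PyTorch's
elementwise semantics: aten.round rounds halfway cases to the nearest even
value, while floor, ceil, and trunc correspond to rmTowardNegative,
rmTowardPositive, and rmTowardZero. A minimal PyTorch sketch (illustrative
only, not itself part of the diff) reproducing the values the new tests check:

    import torch

    # round() uses round-half-to-even, matching APFloat::rmNearestTiesToEven.
    torch.round(torch.tensor([2.5, -2.5, 1.5, -1.5]))  # -> [2., -2., 2., -2.]

    # floor/ceil/trunc match rmTowardNegative/rmTowardPositive/rmTowardZero.
    torch.floor(torch.tensor([1.9]))   # -> [1.]
    torch.ceil(torch.tensor([-1.1]))   # -> [-1.]
    torch.trunc(torch.tensor([-3.7]))  # -> [-3.]

    # NaN and inf pass through unchanged, so discarding the opStatus returned
    # by roundToIntegral in foldSplatRounding is safe.
    torch.round(torch.tensor([float("inf"), float("nan")]))  # -> [inf, nan]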