
Commit 175d2a2

Browse files
committed
[Torch] Fold aten rounding ops on splat constants.
This commit teaches the folding methods of `AtenFloor`, `AtenCeil`, `AtenRound`, and `AtenTrunc` to constant-fold the rounding when the operand is a splat `DenseElementsAttr`.
1 parent: 8b77de9
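For reference, a minimal PyTorch sketch (illustration only, not part of this commit) of the rounding behavior each folder mirrors, annotated with the APFloat rounding mode the corresponding folder passes to the new helper:

import torch

# Sample values covering ties, a negative value, and the special cases.
x = torch.tensor([1.5, 2.5, -3.7, float("inf"), float("nan")])

print(torch.floor(x))  # toward -inf   -> folder uses APFloat::rmTowardNegative
print(torch.ceil(x))   # toward +inf   -> folder uses APFloat::rmTowardPositive
print(torch.round(x))  # half to even  -> folder uses APFloat::rmNearestTiesToEven (1.5 -> 2.0, 2.5 -> 2.0)
print(torch.trunc(x))  # toward zero   -> folder uses APFloat::rmTowardZero
# NaN and +/-inf pass through all four ops unchanged, which is why the helper
# can discard the status returned by roundToIntegral.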

3 files changed, +191 -12 lines changed

lib/Dialect/Torch/IR/TorchOps.cpp

Lines changed: 69 additions & 12 deletions
@@ -234,6 +234,28 @@ static Value getScalarFloatValue(Value input, Location loc,
   return nullptr;
 }
 
+// Common helper for splat-only rounding-based folders.
+static OpFoldResult foldSplatRounding(ValueTensorType resultType,
+                                      Attribute selfAttr,
+                                      APFloat::roundingMode mode) {
+  auto elems = dyn_cast_or_null<DenseElementsAttr>(selfAttr);
+  if (!elems || !elems.isSplat())
+    return {};
+
+  if (!isa<mlir::FloatType>(resultType.getDtype()))
+    return {};
+
+  auto outShaped = resultType.toBuiltinTensor();
+  if (!outShaped.hasStaticShape())
+    return {};
+
+  APFloat v = elems.getSplatValue<APFloat>();
+  // NaNs and infs are dealt with consistently with torch, so side-effects
+  // can be discarded.
+  (void)v.roundToIntegral(mode);
+  return DenseElementsAttr::get(outShaped, v);
+}
+
 //===----------------------------------------------------------------------===//
 // MethodOp
 //===----------------------------------------------------------------------===//
@@ -2064,10 +2086,19 @@ OpFoldResult AtenLogOp::fold(FoldAdaptor adaptor) {
 
 OpFoldResult AtenFloorOp::fold(FoldAdaptor adaptor) {
   auto resultType = dyn_cast<ValueTensorType>(getType());
-  if (resultType && resultType.hasDtype() &&
-      isa<mlir::IntegerType>(resultType.getDtype())) {
+
+  if (!resultType || !resultType.hasDtype())
+    return {};
+
+  // No-op if the result is int, fold.
+  if (isa<mlir::IntegerType>(resultType.getDtype()))
     return getSelf();
-  }
+
+  // Fold float splats.
+  if (auto res = foldSplatRounding(resultType, /*selfAttr*/ adaptor.getSelf(),
+                                   /*mode*/ APFloat::rmTowardNegative))
+    return res;
+
   return {};
 }
 
@@ -2077,10 +2108,19 @@ OpFoldResult AtenFloorOp::fold(FoldAdaptor adaptor) {
 
 OpFoldResult AtenCeilOp::fold(FoldAdaptor adaptor) {
   auto resultType = dyn_cast<ValueTensorType>(getType());
-  if (resultType && resultType.hasDtype() &&
-      isa<mlir::IntegerType>(resultType.getDtype())) {
+
+  if (!resultType || !resultType.hasDtype())
+    return {};
+
+  // No-op if the result is int, fold.
+  if (isa<mlir::IntegerType>(resultType.getDtype()))
     return getSelf();
-  }
+
+  // Fold float splats.
+  if (auto res = foldSplatRounding(resultType, /*selfAttr*/ adaptor.getSelf(),
+                                   /*mode*/ APFloat::rmTowardPositive))
+    return res;
+
   return {};
 }
 
@@ -2103,10 +2143,18 @@ OpFoldResult AtenRoundDecimalsOp::fold(FoldAdaptor adaptor) {
 
 OpFoldResult AtenRoundOp::fold(FoldAdaptor adaptor) {
   auto resultType = dyn_cast<ValueTensorType>(getType());
-  if (resultType && resultType.hasDtype() &&
-      isa<mlir::IntegerType>(resultType.getDtype())) {
+  if (!resultType || !resultType.hasDtype())
+    return {};
+
+  // No-op if the result is int, fold.
+  if (isa<mlir::IntegerType>(resultType.getDtype()))
     return getSelf();
-  }
+
+  // Fold float splats.
+  if (auto res = foldSplatRounding(resultType, /*selfAttr*/ adaptor.getSelf(),
+                                   /*mode*/ APFloat::rmNearestTiesToEven))
+    return res;
+
   return {};
 }
 
@@ -2116,10 +2164,19 @@ OpFoldResult AtenRoundOp::fold(FoldAdaptor adaptor) {
 
 OpFoldResult AtenTruncOp::fold(FoldAdaptor adaptor) {
   auto resultType = dyn_cast<ValueTensorType>(getType());
-  if (resultType && resultType.hasDtype() &&
-      isa<mlir::IntegerType>(resultType.getDtype())) {
+
+  if (!resultType || !resultType.hasDtype())
+    return {};
+
+  // No-op if the result is int, fold.
+  if (isa<mlir::IntegerType>(resultType.getDtype()))
     return getSelf();
-  }
+
+  // Fold float splats.
+  if (auto res = foldSplatRounding(resultType, /*selfAttr*/ adaptor.getSelf(),
+                                   /*mode*/ APFloat::rmTowardZero))
+    return res;
+
   return {};
 }

projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py

Lines changed: 64 additions & 0 deletions
@@ -6844,6 +6844,70 @@ def AtenRoundFloatDecimalsModule_basic(module, tu: TestUtils):
     module.forward(tu.rand(5, 5, low=-3.0, high=3.0))
 
 
+class AtenRoundNegFloatHalfToEvenSplatModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.const = torch.tensor([-1.5, -1.5], dtype=torch.float32)
+
+    @export
+    @annotate_args([None])
+    def forward(self):
+        return torch.ops.aten.round(self.const)
+
+
+@register_test_case(module_factory=lambda: AtenRoundNegFloatHalfToEvenSplatModule())
+def AtenRoundNegFloatHalfToEvenSplatModule_basic(module, tu: TestUtils):
+    module.forward()
+
+
+class AtenRoundPosFloatHalfToEvenSplatModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.const = torch.tensor([1.5, 1.5], dtype=torch.float32)
+
+    @export
+    @annotate_args([None])
+    def forward(self):
+        return torch.ops.aten.round(self.const)
+
+
+@register_test_case(module_factory=lambda: AtenRoundPosFloatHalfToEvenSplatModule())
+def AtenRoundPosFloatHalfToEvenSplatModule_basic(module, tu: TestUtils):
+    module.forward()
+
+
+class AtenRoundInfSplatModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.const = torch.tensor([float("+inf")], dtype=torch.float32)
+
+    @export
+    @annotate_args([None])
+    def forward(self):
+        return torch.ops.aten.round(self.const)
+
+
+@register_test_case(module_factory=lambda: AtenRoundInfSplatModule())
+def AtenRoundInfSplatModule_basic(module, tu: TestUtils):
+    module.forward()
+
+
+class AtenRoundNanSplatModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.const = torch.tensor([float("nan")], dtype=torch.float32)
+
+    @export
+    @annotate_args([None])
+    def forward(self):
+        return torch.ops.aten.round(self.const)
+
+
+@register_test_case(module_factory=lambda: AtenRoundNanSplatModule())
+def AtenRoundNanSplatModule_basic(module, tu: TestUtils):
+    module.forward()
+
+
 # ==============================================================================
 

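For context (illustration only, not part of the diff): run eagerly in PyTorch, the new splat modules produce the following reference values, which the folded constants must reproduce. Note that aten.round uses half-to-even, so both -1.5 and 1.5 round to their even neighbor.

import torch

print(torch.ops.aten.round(torch.tensor([-1.5, -1.5])))    # tensor([-2., -2.])
print(torch.ops.aten.round(torch.tensor([1.5, 1.5])))      # tensor([2., 2.])
print(torch.ops.aten.round(torch.tensor([float("inf")])))  # tensor([inf])
print(torch.ops.aten.round(torch.tensor([float("nan")])))  # tensor([nan])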
test/Dialect/Torch/canonicalize.mlir

Lines changed: 58 additions & 0 deletions
@@ -3596,3 +3596,61 @@ func.func @torch.aten.full$int_fold() -> !torch.vtensor<[2,1,4],si64> {
   %1 = torch.aten.full %0, %int-Inf, %none, %none, %none, %none : !torch.list<int>, !torch.int, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[2,1,4],si64>
   return %1 : !torch.vtensor<[2,1,4],si64>
 }
+
+// -----
+
+// CHECK-LABEL: @torch.aten.ceil$fold
+// CHECK: %[[C:.*]] = torch.vtensor.literal(dense<-1.000000e+00> : tensor<2x2xf32>)
+// CHECK: return %[[C]]
+func.func @torch.aten.ceil$fold() -> !torch.vtensor<[2,2],f32> {
+  %cst = torch.vtensor.literal(dense<-1.100000e+00> : tensor<2x2xf32>)
+      : !torch.vtensor<[2,2],f32>
+  %r = torch.aten.ceil %cst : !torch.vtensor<[2,2],f32> -> !torch.vtensor<[2,2],f32>
+  return %r : !torch.vtensor<[2,2],f32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @torch.aten.floor$fold
+// CHECK: %[[C:.*]] = torch.vtensor.literal(dense<1.000000e+00> : tensor<3x4xf32>)
+// CHECK: return %[[C]]
+func.func @torch.aten.floor$fold() -> !torch.vtensor<[3,4],f32> {
+  %cst = torch.vtensor.literal(dense<1.900000e+00> : tensor<3x4xf32>)
+      : !torch.vtensor<[3,4],f32>
+  %r = torch.aten.floor %cst : !torch.vtensor<[3,4],f32> -> !torch.vtensor<[3,4],f32>
+  return %r : !torch.vtensor<[3,4],f32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @torch.aten.trunc$fold
+// CHECK: %[[C:.*]] = torch.vtensor.literal(dense<-3.000000e+00> : tensor<1x3xf32>)
+// CHECK: return %[[C]]
+func.func @torch.aten.trunc$fold() -> !torch.vtensor<[1,3],f32> {
+  %cst = torch.vtensor.literal(dense<-3.700000e+00> : tensor<1x3xf32>)
+      : !torch.vtensor<[1,3],f32>
+  %r = torch.aten.trunc %cst : !torch.vtensor<[1,3],f32> -> !torch.vtensor<[1,3],f32>
+  return %r : !torch.vtensor<[1,3],f32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @torch.aten.round$fold
+// CHECK-DAG: %[[POS:.*]] = torch.vtensor.literal(dense<2.000000e+00> : tensor<4x5xf32>)
+// CHECK-DAG: %[[NEG:.*]] = torch.vtensor.literal(dense<-2.000000e+00> : tensor<2x3xf32>)
+// CHECK: return %[[POS]], %[[NEG]]
+func.func @torch.aten.round$fold()
+    -> (!torch.vtensor<[4,5],f32>, !torch.vtensor<[2,3],f32>) {
+  %cpos = torch.vtensor.literal(dense<2.500000e+00> : tensor<4x5xf32>)
+      : !torch.vtensor<[4,5],f32>
+  %rpos = torch.aten.round %cpos
+      : !torch.vtensor<[4,5],f32> -> !torch.vtensor<[4,5],f32>
+
+  %cneg = torch.vtensor.literal(dense<-2.500000e+00> : tensor<2x3xf32>)
+      : !torch.vtensor<[2,3],f32>
+  %rneg = torch.aten.round %cneg
+      : !torch.vtensor<[2,3],f32> -> !torch.vtensor<[2,3],f32>
+
+  return %rpos, %rneg
+      : !torch.vtensor<[4,5],f32>, !torch.vtensor<[2,3],f32>
+}
