Skip to content

Commit 7c71753

Browse files
committed
[linalg] : Use -realmax instead of -inf for MaxPool init.
1 parent 288cd5e commit 7c71753

File tree

2 files changed

+10
-10
lines changed

2 files changed

+10
-10
lines changed

lib/Conversion/TorchToLinalg/Pooling.cpp

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -441,9 +441,9 @@ class ConvertAtenMaxPoolOp : public OpConversionPattern<OpTy> {
441441
Value self = adaptor.getSelf();
442442
Type elementType = cast<RankedTensorType>(self.getType()).getElementType();
443443
TypedAttr smallestFPValueAttr = rewriter.getFloatAttr(
444-
elementType,
445-
APFloat::getInf(cast<mlir::FloatType>(elementType).getFloatSemantics(),
446-
/*Negative=*/true));
444+
elementType, APFloat::getLargest(
445+
cast<mlir::FloatType>(elementType).getFloatSemantics(),
446+
/*Negative=*/true));
447447
Value initValue =
448448
rewriter.create<arith::ConstantOp>(op->getLoc(), smallestFPValueAttr);
449449

@@ -693,7 +693,7 @@ class ConvertAtenMaxPoolOp : public OpConversionPattern<OpTy> {
693693
if (auto fpty = dyn_cast<mlir::FloatType>(elementType)) {
694694
smallestValueAttr = rewriter.getFloatAttr(
695695
elementType,
696-
APFloat::getInf(fpty.getFloatSemantics(), /*Negative=*/true));
696+
APFloat::getLargest(fpty.getFloatSemantics(), /*Negative=*/true));
697697
} else if (auto intTy = dyn_cast<mlir::IntegerType>(elementType)) {
698698
int64_t bw = intTy.getIntOrFloatBitWidth();
699699
smallestValueAttr = rewriter.getIntegerAttr(
@@ -1379,9 +1379,9 @@ class AdaptiveMaxPoolingHelper : public AdaptivePoolingHelper {
13791379
typeConverter->convertType(op.getResult1().getType()));
13801380
Type auxTensorElementType = auxTensorType.getElementType();
13811381
auto smallestFPValueAttr = rewriter.getFloatAttr(
1382-
elementType,
1383-
APFloat::getInf(cast<mlir::FloatType>(elementType).getFloatSemantics(),
1384-
/*Negative=*/true));
1382+
elementType, APFloat::getLargest(
1383+
cast<mlir::FloatType>(elementType).getFloatSemantics(),
1384+
/*Negative=*/true));
13851385
buffVal = rewriter.create<arith::ConstantOp>(loc, elementType,
13861386
smallestFPValueAttr);
13871387
auxTensor = rewriter.create<tensor::EmptyOp>(

test/Conversion/TorchToLinalg/pooling.mlir

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ func.func @forward_max_pool1d(%arg0: !torch.vtensor<[?,?,?],f32>) -> !torch.vten
77
%int3 = torch.constant.int 3
88
%int4 = torch.constant.int 4
99
%false = torch.constant.bool false
10-
// CHECK: %[[NEUTRAL:.*]] = arith.constant 0xFF800000 : f32
10+
// CHECK: %[[NEUTRAL:.*]] = arith.constant -3.40282347E+38 : f32
1111
// CHECK: %[[PADDED:.*]] = tensor.pad %{{.*}} low[0, 0, 3] high[0, 0, 3]
1212
// CHECK: %[[OUT:.*]] = linalg.fill ins(%[[NEUTRAL]] : f32) outs(%{{.*}} : tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
1313
// CHECK: %[[INIT:.*]] = tensor.empty() : tensor<1xf32>
@@ -33,7 +33,7 @@ func.func @forward_max_pool2d(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vt
3333
%int7 = torch.constant.int 7
3434
%int8 = torch.constant.int 8
3535
%false = torch.constant.bool false
36-
// CHECK: %[[NEUTRAL:.*]] = arith.constant 0xFF800000 : f32
36+
// CHECK: %[[NEUTRAL:.*]] = arith.constant -3.40282347E+38 : f32
3737
// CHECK: %[[PADDED:.*]] = tensor.pad %{{.*}} low[0, 0, 5, 6] high[0, 0, 5, 6]
3838
// CHECK: %[[OUT:.*]] = linalg.fill ins(%[[NEUTRAL]] : f32) outs(%{{.*}} : tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32>
3939
// CHECK: %[[INIT:.*]] = tensor.empty() : tensor<1x2xf32>
@@ -106,7 +106,7 @@ func.func @forward_max_pool3d(%arg0: !torch.vtensor<[?,?,?,?,?],f32>) -> !torch.
106106

107107
%4 = torch.aten.max_pool3d %arg0, %kernel_size, %stride, %padding, %dilation, %false : !torch.vtensor<[?,?,?,?,?],f32>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool -> !torch.vtensor<[?,?,?,?,?],f32>
108108

109-
// CHECK: %[[MIN_VALUE:.*]] = arith.constant 0xFF800000 : f32
109+
// CHECK: %[[MIN_VALUE:.*]] = arith.constant -3.40282347E+38 : f32
110110
// CHECK: %[[PADDED_INPUT_TENSOR:.*]] = tensor.pad %{{.*}} low[0, 0, 4, 4, 4] high[0, 0, 4, 4, 4] {
111111
// CHECK-NEXT: ^bb0(%{{.*}}: index, %{{.*}}: index, %{{.*}}: index, %{{.*}}: index, %{{.*}}: index):
112112
// CHECK-NEXT: tensor.yield %[[MIN_VALUE:.*]] : f32

0 commit comments

Comments
 (0)