Commit 05be717

[linalg] : Use (-)realmax instead of (-)inf to avoid usage of non-finites.
1 parent 288cd5e commit 05be717
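In APFloat terms, the patch replaces each APFloat::getInf initializer with APFloat::getLargest of the same sign: the largest finite value of the given float semantics (realmax in MATLAB-speak, FLT_MAX for f32) instead of infinity, so the emitted IR never carries a non-finite constant. A minimal standalone sketch of the two calls, outside any rewrite pattern, using only the real LLVM APFloat API:

#include "llvm/ADT/APFloat.h"
#include "llvm/Support/raw_ostream.h"

int main() {
  const llvm::fltSemantics &sem = llvm::APFloat::IEEEsingle();
  // Old seed value: negative infinity, a non-finite constant.
  llvm::APFloat negInf = llvm::APFloat::getInf(sem, /*Negative=*/true);
  // New seed value: the most negative *finite* f32, i.e. -FLT_MAX.
  llvm::APFloat negMax = llvm::APFloat::getLargest(sem, /*Negative=*/true);
  llvm::outs() << (negInf.isFinite() ? "finite" : "non-finite") << "\n"; // non-finite
  llvm::outs() << (negMax.isFinite() ? "finite" : "non-finite") << "\n"; // finite
  return 0;
}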

7 files changed: +25 -25 lines changed

lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp

Lines changed: 2 additions & 2 deletions
@@ -1548,8 +1548,8 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
         auto dty = dataTy.getDtype();
         Value scalar;
         if (FloatType fpTy = dyn_cast<FloatType>(dty)) {
-          auto inf =
-              APFloat::getInf(fpTy.getFloatSemantics(), /*Negative=*/true);
+          auto inf = APFloat::getLargest(fpTy.getFloatSemantics(),
+                                         /*Negative=*/true);
           scalar = rewriter.create<Torch::ConstantFloatOp>(
               binder.getLoc(), rewriter.getType<Torch::FloatType>(),
               rewriter.getFloatAttr(rewriter.getF64Type(),

lib/Conversion/TorchToLinalg/Pooling.cpp

Lines changed: 7 additions & 7 deletions
@@ -441,9 +441,9 @@ class ConvertAtenMaxPoolOp : public OpConversionPattern<OpTy> {
     Value self = adaptor.getSelf();
     Type elementType = cast<RankedTensorType>(self.getType()).getElementType();
     TypedAttr smallestFPValueAttr = rewriter.getFloatAttr(
-        elementType,
-        APFloat::getInf(cast<mlir::FloatType>(elementType).getFloatSemantics(),
-                        /*Negative=*/true));
+        elementType, APFloat::getLargest(
+                         cast<mlir::FloatType>(elementType).getFloatSemantics(),
+                         /*Negative=*/true));
     Value initValue =
         rewriter.create<arith::ConstantOp>(op->getLoc(), smallestFPValueAttr);

@@ -693,7 +693,7 @@ class ConvertAtenMaxPoolOp : public OpConversionPattern<OpTy> {
     if (auto fpty = dyn_cast<mlir::FloatType>(elementType)) {
       smallestValueAttr = rewriter.getFloatAttr(
           elementType,
-          APFloat::getInf(fpty.getFloatSemantics(), /*Negative=*/true));
+          APFloat::getLargest(fpty.getFloatSemantics(), /*Negative=*/true));
     } else if (auto intTy = dyn_cast<mlir::IntegerType>(elementType)) {
       int64_t bw = intTy.getIntOrFloatBitWidth();
       smallestValueAttr = rewriter.getIntegerAttr(
@@ -1379,9 +1379,9 @@ class AdaptiveMaxPoolingHelper : public AdaptivePoolingHelper {
         typeConverter->convertType(op.getResult1().getType()));
     Type auxTensorElementType = auxTensorType.getElementType();
     auto smallestFPValueAttr = rewriter.getFloatAttr(
-        elementType,
-        APFloat::getInf(cast<mlir::FloatType>(elementType).getFloatSemantics(),
-                        /*Negative=*/true));
+        elementType, APFloat::getLargest(
+                         cast<mlir::FloatType>(elementType).getFloatSemantics(),
+                         /*Negative=*/true));
     buffVal = rewriter.create<arith::ConstantOp>(loc, elementType,
                                                  smallestFPValueAttr);
     auxTensor = rewriter.create<tensor::EmptyOp>(
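Functionally the swap is benign for these max-pool seeds: a max-reduction over any finite f32 data seeded with -FLT_MAX returns the same value as one seeded with -inf; the two seeds differ only when the reduction window is empty, where returning a finite value is exactly the point of the patch. A small standalone check, plain C++ standing in for the generated linalg code:

#include <algorithm>
#include <cstdio>
#include <limits>

int main() {
  // Seed a max-reduction once with -inf (old) and once with -FLT_MAX (new).
  float accInf = -std::numeric_limits<float>::infinity();
  float accMax = -std::numeric_limits<float>::max();
  const float data[] = {-1.5f, 0.0f, 2.5f};
  for (float x : data) {
    accInf = std::max(accInf, x);
    accMax = std::max(accMax, x);
  }
  std::printf("%g %g\n", accInf, accMax); // 2.5 2.5 -- identical results
  return 0;
}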

lib/Conversion/TorchToLinalg/Reduction.cpp

Lines changed: 3 additions & 3 deletions
@@ -117,7 +117,7 @@ class ConvertAtenMinMaxDimOp : public OpConversionPattern<OpTy> {
       fillValue = rewriter.create<arith::ConstantOp>(
           loc, rewriter.getFloatAttr(
                    inElementType,
-                   APFloat::getInf(
+                   APFloat::getLargest(
                        cast<mlir::FloatType>(inElementType).getFloatSemantics(),
                        /*Negative=*/isMax)));
     } else if (!isUnsigned) {
@@ -302,7 +302,7 @@ static Value createInitElementForReduceOp(OpBuilder &b, Location loc,
     return b.create<arith::ConstantOp>(
         loc, b.getFloatAttr(
                 elementType,
-                APFloat::getInf(
+                APFloat::getLargest(
                     cast<mlir::FloatType>(elementType).getFloatSemantics(),
                     /*Negative=*/true)));
   else if (isa<mlir::IntegerType>(elementType) &&
@@ -318,7 +318,7 @@ static Value createInitElementForReduceOp(OpBuilder &b, Location loc,
     return b.create<arith::ConstantOp>(
         loc, b.getFloatAttr(
                 elementType,
-                APFloat::getInf(
+                APFloat::getLargest(
                     cast<mlir::FloatType>(elementType).getFloatSemantics(),
                     /*Negative=*/false)));
   else if (isa<mlir::IntegerType>(elementType) &&

lib/Conversion/TorchToStablehlo/Reduction.cpp

Lines changed: 6 additions & 6 deletions
@@ -62,9 +62,9 @@ static Value createInitialValueForReduceOp(Operation *op, Type elementTy,
   if (isa<AtenAmaxOp, AtenMaxOp, AtenMaxDimOp, AtenArgmaxOp>(op)) {
     if (isa<mlir::FloatType>(elementTy)) {
       constAttr = DenseElementsAttr::get(
-          constType,
-          {APFloat::getInf(cast<mlir::FloatType>(elementTy).getFloatSemantics(),
-                           /*negative=*/true)});
+          constType, {APFloat::getLargest(
+                         cast<mlir::FloatType>(elementTy).getFloatSemantics(),
+                         /*negative=*/true)});
     } else if (isa<mlir::IntegerType>(elementTy)) {
       constAttr = DenseElementsAttr::get(
           constType,
@@ -75,9 +75,9 @@ static Value createInitialValueForReduceOp(Operation *op, Type elementTy,
   if (isa<AtenAminOp, AtenMinOp, AtenMinDimOp, AtenArgminOp>(op)) {
     if (isa<mlir::FloatType>(elementTy)) {
       constAttr = DenseElementsAttr::get(
-          constType,
-          {APFloat::getInf(cast<mlir::FloatType>(elementTy).getFloatSemantics(),
-                           /*negative=*/false)});
+          constType, {APFloat::getLargest(
+                         cast<mlir::FloatType>(elementTy).getFloatSemantics(),
+                         /*negative=*/false)});
     } else if (isa<mlir::IntegerType>(elementTy)) {
       constAttr = DenseElementsAttr::get(
           constType,

lib/Conversion/TorchToTMTensor/TorchToTMTensor.cpp

Lines changed: 2 additions & 2 deletions
@@ -2072,15 +2072,15 @@ class ConvertAtenKthvalueOp : public OpConversionPattern<AtenKthvalueOp> {
           loc,
           rewriter.getFloatAttr(
               inputElementType,
-              APFloat::getInf(
+              APFloat::getLargest(
                   cast<mlir::FloatType>(inputElementType).getFloatSemantics(),
                   /*Negative=*/false)));
       // min float for linalg generic op tensor
       fillValLinalgFindMax = rewriter.create<arith::ConstantOp>(
           loc,
           rewriter.getFloatAttr(
               inputElementType,
-              APFloat::getInf(
+              APFloat::getLargest(
                   cast<mlir::FloatType>(inputElementType).getFloatSemantics(),
                   /*Negative=*/true)));
     } else if (!isUnsigned) {

test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir

Lines changed: 2 additions & 2 deletions
@@ -805,13 +805,13 @@ func.func @test_selu(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,

 // CHECK-LABEL: func.func @test_reduce_max_empty_set_fp
 func.func @test_reduce_max_empty_set_fp(%arg0: !torch.vtensor<[2,0,4],f32>, %arg1: !torch.vtensor<[1],si64>) -> !torch.vtensor<[2,1,4],f32> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 20 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
-  // CHECK-DAG: %[[INF:.+]] = torch.constant.float 0xFFF0000000000000
+  // CHECK-DAG: %[[NEGMAX:.+]] = torch.constant.float -3.4028234663852886E+38
   // CHECK-DAG: %[[INT2:.+]] = torch.constant.int 2
   // CHECK-DAG: %[[INT1:.+]] = torch.constant.int 1
   // CHECK-DAG: %[[INT4:.+]] = torch.constant.int 4
   // CHECK-DAG: %[[NONE:.+]] = torch.constant.none
   // CHECK-DAG: %[[LIST:.+]] = torch.prim.ListConstruct %[[INT2]], %[[INT1]], %[[INT4]]
-  // CHECK-DAG: %[[FULL:.+]] = torch.aten.full %[[LIST]], %[[INF]], %[[NONE]], %[[NONE]], %[[NONE]]
+  // CHECK-DAG: %[[FULL:.+]] = torch.aten.full %[[LIST]], %[[NEGMAX]], %[[NONE]], %[[NONE]], %[[NONE]]
   // CHECK: return %[[FULL]]
   %0 = torch.operator "onnx.ReduceMax"(%arg0, %arg1) {torch.onnx.keepdims = 1 : si64} : (!torch.vtensor<[2,0,4],f32>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[2,1,4],f32>
   return %0 : !torch.vtensor<[2,1,4],f32>
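The new check constant here is simply the f32 largest value widened to double: this lowering materializes the scalar as a torch.constant.float, which is an f64 attribute (note the rewriter.getF64Type() call in the DefaultDomainQtoZ.cpp hunk above), and 3.4028234663852886E+38 is FLT_MAX printed with the 17 significant digits needed to round-trip an IEEE double. A one-liner reproducing the literal, assuming only the C++ standard library:

#include <cstdio>
#include <limits>

int main() {
  // torch.constant.float stores an f64, so the f32 largest value is
  // widened to double; 17 significant digits round-trip a double exactly.
  const double negMax = -static_cast<double>(std::numeric_limits<float>::max());
  std::printf("%.17G\n", negMax); // prints -3.4028234663852886E+38
  return 0;
}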

test/Conversion/TorchToLinalg/pooling.mlir

Lines changed: 3 additions & 3 deletions
@@ -7,7 +7,7 @@ func.func @forward_max_pool1d(%arg0: !torch.vtensor<[?,?,?],f32>) -> !torch.vten
   %int3 = torch.constant.int 3
   %int4 = torch.constant.int 4
   %false = torch.constant.bool false
-  // CHECK: %[[NEUTRAL:.*]] = arith.constant 0xFF800000 : f32
+  // CHECK: %[[NEUTRAL:.*]] = arith.constant -3.40282347E+38 : f32
   // CHECK: %[[PADDED:.*]] = tensor.pad %{{.*}} low[0, 0, 3] high[0, 0, 3]
   // CHECK: %[[OUT:.*]] = linalg.fill ins(%[[NEUTRAL]] : f32) outs(%{{.*}} : tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
   // CHECK: %[[INIT:.*]] = tensor.empty() : tensor<1xf32>
@@ -33,7 +33,7 @@ func.func @forward_max_pool2d(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vt
   %int7 = torch.constant.int 7
   %int8 = torch.constant.int 8
   %false = torch.constant.bool false
-  // CHECK: %[[NEUTRAL:.*]] = arith.constant 0xFF800000 : f32
+  // CHECK: %[[NEUTRAL:.*]] = arith.constant -3.40282347E+38 : f32
   // CHECK: %[[PADDED:.*]] = tensor.pad %{{.*}} low[0, 0, 5, 6] high[0, 0, 5, 6]
   // CHECK: %[[OUT:.*]] = linalg.fill ins(%[[NEUTRAL]] : f32) outs(%{{.*}} : tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32>
   // CHECK: %[[INIT:.*]] = tensor.empty() : tensor<1x2xf32>
@@ -106,7 +106,7 @@ func.func @forward_max_pool3d(%arg0: !torch.vtensor<[?,?,?,?,?],f32>) -> !torch.

   %4 = torch.aten.max_pool3d %arg0, %kernel_size, %stride, %padding, %dilation, %false : !torch.vtensor<[?,?,?,?,?],f32>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool -> !torch.vtensor<[?,?,?,?,?],f32>

-  // CHECK: %[[MIN_VALUE:.*]] = arith.constant 0xFF800000 : f32
+  // CHECK: %[[MIN_VALUE:.*]] = arith.constant -3.40282347E+38 : f32
   // CHECK: %[[PADDED_INPUT_TENSOR:.*]] = tensor.pad %{{.*}} low[0, 0, 4, 4, 4] high[0, 0, 4, 4, 4] {
   // CHECK-NEXT: ^bb0(%{{.*}}: index, %{{.*}}: index, %{{.*}}: index, %{{.*}}: index, %{{.*}}: index):
   // CHECK-NEXT: tensor.yield %[[MIN_VALUE:.*]] : f32
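The old CHECK value 0xFF800000 was not an arbitrary magic number: MLIR prints non-finite float constants by their raw bit pattern, and 0xFF800000 is the IEEE-754 f32 encoding of -inf (sign 1, exponent all ones, mantissa 0). The finite replacement prints as an ordinary decimal literal, -3.40282347E+38, which is FLT_MAX at the 9 significant digits that round-trip an f32. A standalone sketch reproducing both renderings:

#include <cstdint>
#include <cstdio>
#include <cstring>
#include <limits>

int main() {
  // Bit pattern of -inf in f32: sign=1, exponent=0xFF, mantissa=0.
  const float negInf = -std::numeric_limits<float>::infinity();
  std::uint32_t bits;
  std::memcpy(&bits, &negInf, sizeof bits);
  std::printf("0x%08X\n", bits); // 0xFF800000
  // The finite replacement, printed at f32 round-trip precision:
  std::printf("%.9G\n", -std::numeric_limits<float>::max()); // -3.40282347E+38
  return 0;
}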
