
Commit 0bd2f12

[mlir][linalg] Restrict fill initial value type to output element type (#169567)
Disallow implicit casting, which is surprising and, in my experience, usually indicative of copy-paste errors. Because the initial value must be a scalar, I don't expect this to affect any data movement.
1 parent b228256 commit 0bd2f12
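
In practice, a `linalg.fill` whose scalar type differs from the output element type is now rejected by the verifier; any cast has to be written explicitly before the fill. A minimal MLIR sketch of the migration (the values `%v` and `%t` are illustrative, not taken from this commit):

```mlir
// Now rejected: the f64 value no longer implicitly casts to the f32
// element type of the output.
//   %r = linalg.fill ins(%v : f64) outs(%t : tensor<4xf32>) -> tensor<4xf32>

// Accepted: cast the scalar explicitly so the types match.
%v32 = arith.truncf %v : f64 to f32
%r = linalg.fill ins(%v32 : f32) outs(%t : tensor<4xf32>) -> tensor<4xf32>
```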

11 files changed: 81 additions & 75 deletions


mlir/docs/Dialects/Linalg/OpDSL.md

Lines changed: 9 additions & 8 deletions
@@ -311,16 +311,17 @@ An example for a rank polymorphic operation is `fill`:

 ```python
 @linalg_structured_op
-def fill(value=ScalarDef(T1),
-         O=TensorDef(U, output=True)):
-  O[None] = TypeFn.cast_signed(U, value)
+def fill(value=ScalarDef(T),
+         O=TensorDef(T, output=True)):
+  O[None] = value
 ```

-The operation sets the elements of the output tensor `O` to `value`. All
-operands are either scalars or rank zero tensors that are accessed using the
-index `None`. The operation thus performs a scalar computation that trivially
-extends to a multi-dimensional pointwise computation. As a result, we may use
-`fill` with arbitrary ranked output tensors:
+The operation sets the elements of the output tensor `O` to `value`. The value
+type must match the element type of the output tensor. All operands are either
+scalars or rank zero tensors that are accessed using the index `None`. The
+operation thus performs a scalar computation that trivially extends to a
+multi-dimensional pointwise computation. As a result, we may use `fill` with
+arbitrary ranked output tensors:

 ```python
 tensor_2d = tensor.EmptyOp([4, 8], f32)

mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml

Lines changed: 6 additions & 12 deletions
@@ -6054,9 +6054,9 @@ metadata: !LinalgOpMetadata
   doc: |-
     Fills the output tensor with the given value.

-    Works for arbitrary ranked output tensors since the operation performs scalar
-    accesses only and is thus rank polymorphic. Numeric casting is performed on
-    the value operand, promoting it to the same data type as the output.
+    Works for arbitrary ranked output tensors since the operation performs
+    scalar accesses only and is thus rank polymorphic. The value operand
+    type must match the element type of the output.
   implements:
   - LinalgFillOpInterface
   defines:
@@ -6066,11 +6066,11 @@ structured_op: !LinalgStructuredOpConfig
   - !LinalgOperandDefConfig
     name: value
     kind: scalar
-    type_var: T1
+    type_var: T
   - !LinalgOperandDefConfig
     name: O
     kind: output_tensor
-    type_var: U
+    type_var: T
     shape_map: affine_map<() -> ()>
   indexing_maps: !LinalgIndexingMapsConfig
     static_indexing_maps:
@@ -6081,13 +6081,7 @@ structured_op: !LinalgStructuredOpConfig
   - !ScalarAssign
     arg: O
     value: !ScalarExpression
-      scalar_fn:
-        kind: type
-        fn_name: cast_signed
-        type_var: U
-        operands:
-        - !ScalarExpression
-          scalar_arg: value
+      scalar_arg: value
 --- !LinalgOpConfig
 metadata: !LinalgOpMetadata
   name: fill_rng_2d
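
For reference, a minimal MLIR sketch of the rank-polymorphic behavior the doc string describes, with the now-required matching types (operand names are illustrative):

```mlir
// The same op fills a 0-d tensor and a 2-d memref; in both cases the
// f32 scalar must equal the output element type.
%filled = linalg.fill ins(%v : f32) outs(%t0 : tensor<f32>) -> tensor<f32>
linalg.fill ins(%v : f32) outs(%buf : memref<16x32xf32>)
```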

mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp

Lines changed: 21 additions & 2 deletions
@@ -1057,12 +1057,15 @@ LogicalResult mlir::linalg::detail::verifyConvolutionInterface(Operation *op) {
 // FillOpInterface implementation
 //===----------------------------------------------------------------------===//

+namespace {
 enum class MatchFillResult {
   Success = 0,
   NotLinalgOp,
   WrongNumOperands,
-  NotScalarInput
+  NotScalarInput,
+  TypeMismatch
 };
+} // namespace

 static MatchFillResult isFillInterfaceImpl(Operation *op) {
   auto linalgOp = dyn_cast<linalg::LinalgOp>(op);
@@ -1075,17 +1078,33 @@ static MatchFillResult isFillInterfaceImpl(Operation *op) {
   if (!linalgOp.isScalar(value))
     return MatchFillResult::NotScalarInput;

+  // Check that the scalar input type matches the output element type.
+  OpOperand *output = linalgOp.getDpsInitOperand(0);
+  Type scalarType = value->get().getType();
+  Type outputElementType = getElementTypeOrSelf(output->get().getType());
+  if (scalarType != outputElementType)
+    return MatchFillResult::TypeMismatch;
+
   return MatchFillResult::Success;
 }

 LogicalResult mlir::linalg::detail::verifyFillInterface(Operation *op) {
-  auto res = isFillInterfaceImpl(op);
+  MatchFillResult res = isFillInterfaceImpl(op);
   if (res == MatchFillResult::NotLinalgOp)
     return op->emitError("expected a LinalgOp");
   if (res == MatchFillResult::WrongNumOperands)
     return op->emitError("expected op with 1 input and 1 output");
   if (res == MatchFillResult::NotScalarInput)
     return op->emitError("expected op with scalar input");
+  if (res == MatchFillResult::TypeMismatch) {
+    auto linalgOp = cast<linalg::LinalgOp>(op);
+    Type scalarType = linalgOp.getDpsInputOperand(0)->get().getType();
+    Type outputElementType =
+        getElementTypeOrSelf(linalgOp.getDpsInitOperand(0)->get().getType());
+    return op->emitOpError("expected fill value type (")
+           << scalarType << ") to match output element type ("
+           << outputElementType << ")";
+  }

   return success();
 }

mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py

Lines changed: 4 additions & 4 deletions
@@ -1729,16 +1729,16 @@ def pooling_ndhwc_min(


 @linalg_structured_op
-def fill(value=ScalarDef(T1), O=TensorDef(U, output=True)):
+def fill(value=ScalarDef(T), O=TensorDef(T, output=True)):
     """Fills the output tensor with the given value.

     Works for arbitrary ranked output tensors since the operation performs scalar
-    accesses only and is thus rank polymorphic. Numeric casting is performed on
-    the value operand, promoting it to the same data type as the output.
+    accesses only and is thus rank polymorphic. The value type must match the
+    element type of the output tensor or memref.
     """
     implements(FillOpInterface)
     defines(Canonicalizer)
-    O[None] = TypeFn.cast_signed(U, value)
+    O[None] = value


 @linalg_structured_op

mlir/test/Dialect/Affine/value-bounds-reification.mlir

Lines changed: 2 additions & 2 deletions
@@ -36,13 +36,13 @@ func.func @reify_through_chain(%sz0: index, %sz2: index) -> (index, index, index
 // CHECK: "test.some_use"(%[[c5]])
 // CHECK: %[[c5:.*]] = arith.constant 5 : index
 // CHECK: "test.some_use"(%[[c5]])
-func.func @reify_slice_bound(%t: tensor<?x?xi32>, %idx: index, %ub: index, %f: f32) {
+func.func @reify_slice_bound(%t: tensor<?x?xi32>, %idx: index, %ub: index, %f: i32) {
   %c0 = arith.constant 0 : index
   %c4 = arith.constant 4 : index
   scf.for %iv = %c0 to %ub step %c4 {
     %sz = affine.min affine_map<(d0)[s0] -> (-d0 + s0, 4)>(%iv)[%ub]
     %slice = tensor.extract_slice %t[%idx, %iv] [1, %sz] [1, 1] : tensor<?x?xi32> to tensor<1x?xi32>
-    %filled = linalg.fill ins(%f : f32) outs(%slice : tensor<1x?xi32>) -> tensor<1x?xi32>
+    %filled = linalg.fill ins(%f : i32) outs(%slice : tensor<1x?xi32>) -> tensor<1x?xi32>

     %bound = "test.reify_bound"(%filled) {dim = 1, type = "UB"} : (tensor<1x?xi32>) -> (index)
     "test.some_use"(%bound) : (index) -> ()

mlir/test/Dialect/Linalg/fusion-elementwise-ops.mlir

Lines changed: 1 addition & 25 deletions
@@ -921,30 +921,6 @@ func.func @fold_fill_generic_basic(%arg0: tensor<?xf32>) -> (tensor<?xf32>) {

 // -----

-// CHECK-LABEL: func @fold_fill_generic_different_dtype
-// CHECK-SAME: (%[[ARG0:.*]]: tensor<?xf16>) -> tensor<?xf16> {
-// CHECK-NOT: linalg.fill
-// CHECK: %[[GENERIC_OP:.*]] = linalg.generic
-// CHECK-SAME: ins(%[[ARG0]] : tensor<?xf16>)
-// CHECK-SAME: outs({{.*}} : tensor<?xf16>) {
-#map0 = affine_map<(d0) -> (d0)>
-func.func @fold_fill_generic_different_dtype(%arg0: tensor<?xf16>) -> (tensor<?xf16>) {
-  %c0 = arith.constant 0 : index
-  %cst = arith.constant 7.0 : f32
-  %0 = tensor.dim %arg0, %c0 : tensor<?xf16>
-  %1 = tensor.empty(%0) : tensor<?xf16>
-  %2 = linalg.fill ins(%cst : f32) outs(%1 : tensor<?xf16>) -> tensor<?xf16>
-  %3 = tensor.empty(%0) : tensor<?xf16>
-  %4 = linalg.generic {indexing_maps = [#map0, #map0, #map0], iterator_types=["parallel"]} ins(%arg0, %2 : tensor<?xf16>, tensor<?xf16>) outs (%3:tensor<?xf16>) {
-  ^bb0(%arg1: f16, %arg2: f16, %arg3: f16):
-    %5 = arith.addf %arg1, %arg2 : f16
-    linalg.yield %5 : f16
-  } -> tensor<?xf16>
-  return %4 : tensor<?xf16>
-}
-
-// -----
-
 // CHECK-LABEL: func @fold_fill_generic_mixedaccess
 // CHECK-NOT: linalg.fill
 // CHECK: %[[GENERIC_OP:.*]] = linalg.generic
@@ -1079,4 +1055,4 @@ module {
 // CHECK-NOT: linalg.generic
 // CHECK: tensor.expand_shape
 // CHECK: linalg.generic {{.*}}, iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel", "parallel", "reduction"]}
-// CHECK-SAME: ins(%[[ARG0]], %[[FUSED]]#1 : tensor<1x1x2x1xf32>, tensor<4x1x1x1xf32>)
+// CHECK-SAME: ins(%[[ARG0]], %[[FUSED]]#1 : tensor<1x1x2x1xf32>, tensor<4x1x1x1xf32>)

mlir/test/Dialect/Linalg/generalize-named-polymorphic-ops.mlir

Lines changed: 4 additions & 4 deletions
@@ -380,8 +380,8 @@ func.func @generalize_pooling_nwc_sum_i32(%input : tensor<1x16x1xi32>, %shape: t

 // -----

-func.func @generalize_fill_0d(%value: f64, %O: tensor<f32>) -> tensor<f32> {
-  %0 = linalg.fill ins(%value: f64) outs(%O : tensor<f32>) -> tensor<f32>
+func.func @generalize_fill_0d(%value: f32, %O: tensor<f32>) -> tensor<f32> {
+  %0 = linalg.fill ins(%value: f32) outs(%O : tensor<f32>) -> tensor<f32>
   return %0: tensor<f32>
 }

@@ -394,8 +394,8 @@ func.func @generalize_fill_0d(%value: f64, %O: tensor<f32>) -> tensor<f32> {

 // -----

-func.func @generalize_fill_2d(%value: f64, %O: memref<16x32xf32>) {
-  linalg.fill ins(%value: f64) outs(%O : memref<16x32xf32>)
+func.func @generalize_fill_2d(%value: f32, %O: memref<16x32xf32>) {
+  linalg.fill ins(%value: f32) outs(%O : memref<16x32xf32>)
   return
 }

mlir/test/Dialect/Linalg/invalid.mlir

Lines changed: 18 additions & 0 deletions
@@ -352,6 +352,24 @@ func.func @illegal_fill_tensor_with_memref_return

 // -----

+func.func @illegal_fill_element_type_truncation(%arg0 : tensor<2xf32>, %arg1 : f64) -> tensor<2xf32>
+{
+  // expected-error @+1 {{'linalg.fill' op expected fill value type ('f64') to match output element type ('f32')}}
+  %0 = linalg.fill ins(%arg1 : f64) outs(%arg0 : tensor<2xf32>) -> tensor<2xf32>
+  return %0 : tensor<2xf32>
+}
+
+// -----
+
+func.func @illegal_fill_element_type_extension(%arg0 : tensor<2xi32>, %arg1 : i16) -> tensor<2xi32>
+{
+  // expected-error @+1 {{'linalg.fill' op expected fill value type ('i16') to match output element type ('i32')}}
+  %0 = linalg.fill ins(%arg1 : i16) outs(%arg0 : tensor<2xi32>) -> tensor<2xi32>
+  return %0 : tensor<2xi32>
+}
+
+// -----
+
 func.func @illegal_fill_value_type(%arg0 : tensor<2x2xf32>, %arg1 : tensor<2xf32>) -> tensor<2x2xf32>
 {
   // expected-error @+1 {{expected op with scalar input}}

mlir/test/Integration/Dialect/Linalg/CPU/test-matmul-masked-vec.mlir

Lines changed: 2 additions & 2 deletions
@@ -27,8 +27,8 @@ func.func @main() {
   %A_dyn = tensor.cast %A : tensor<8x2xf32> to tensor<?x?xf32>
   %B_dyn = tensor.cast %B : tensor<2x4xf32> to tensor<?x?xf32>

-  %c0_i32 = arith.constant 0 : i32
-  %C_init = linalg.fill ins(%c0_i32 : i32) outs(%C_dyn : tensor<?x?xf32>) -> tensor<?x?xf32>
+  %c0_f32 = arith.constant 0.0 : f32
+  %C_init = linalg.fill ins(%c0_f32 : f32) outs(%C_dyn : tensor<?x?xf32>) -> tensor<?x?xf32>

   %res = linalg.matmul ins(%A_dyn, %B_dyn: tensor<?x?xf32>, tensor<?x?xf32>)
                        outs(%C_init: tensor<?x?xf32>) -> tensor<?x?xf32>

mlir/test/Integration/Dialect/Transform/match_matmul.mlir

Lines changed: 2 additions & 2 deletions
@@ -63,11 +63,11 @@ func.func @matmul_simple(%lhs: tensor<10x20xf16>, %rhs: tensor<20x15xf32>) -> te
 }

 func.func @matmul_with_extra_ops_in_func(%lhs: tensor<10x20xf32>, %rhs: tensor<20x15xf32>) -> tensor<10x15xf32> {
-  %cst = arith.constant 0.0 : f64
+  %cst = arith.constant 0.0 : f32
   %empty = tensor.empty() : tensor<10x15xf32>

   // expected-remark @below {{fill}}
-  %fill = linalg.fill ins(%cst : f64) outs(%empty : tensor<10x15xf32>) -> tensor<10x15xf32>
+  %fill = linalg.fill ins(%cst : f32) outs(%empty : tensor<10x15xf32>) -> tensor<10x15xf32>

   %real_lhs = linalg.mul
     ins(%lhs, %lhs : tensor<10x20xf32>, tensor<10x20xf32>) outs(%lhs : tensor<10x20xf32>) -> tensor<10x20xf32>
