Skip to content

Commit b5cb207

Browse files
ergawymemfrob
authored andcommitted
[MLIR][LinAlg] Start detensoring implementation.
This commit is the first baby step towards detensoring in linalg-on-tensors. Detensoring is the process through which a tensor value is convereted to one or potentially more primitive value(s). During this process, operations with such detensored operands are also converted to an equivalen form that works on primitives. The detensoring process is driven by linalg-on-tensor ops. In particular, a linalg-on-tensor op is checked to see whether *all* its operands can be detensored. If so, those operands are converted to thier primitive counterparts and the linalg op is replaced by an equivalent op that takes those new primitive values as operands. This works towards handling github/iree-org/iree#1159. Reviewed By: nicolasvasilache Differential Revision: https://reviews.llvm.org/D96271
1 parent 97b367a commit b5cb207

File tree

5 files changed

+309
-0
lines changed

5 files changed

+309
-0
lines changed

mlir/include/mlir/Dialect/Linalg/Passes.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,10 @@ void populateElementwiseToLinalgConversionPatterns(
5959
/// operations.
6060
std::unique_ptr<OperationPass<FuncOp>> createLinalgGeneralizationPass();
6161

62+
/// Create a pass to convert Linalg operations to equivalent operations that
63+
/// work on primitive types, if possible.
64+
std::unique_ptr<Pass> createLinalgDetensorizePass();
65+
6266
/// Patterns to fold an expanding (collapsing) tensor_reshape operation with its
6367
/// producer (consumer) generic operation by expanding the dimensionality of the
6468
/// loop in the generic op.

mlir/include/mlir/Dialect/Linalg/Passes.td

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -136,4 +136,28 @@ def LinalgGeneralization : FunctionPass<"linalg-generalize-named-ops"> {
136136
let dependentDialects = ["linalg::LinalgDialect"];
137137
}
138138

139+
def LinalgDetensorize : FunctionPass<"linalg-detensorize"> {
140+
let summary = "Detensorize linalg ops";
141+
let constructor = "mlir::createLinalgDetensorizePass()";
142+
let dependentDialects = [];
143+
144+
let description = [{
145+
Detensoring is the process through which a tensor value is convereted to one
146+
or potentially more primitive value(s). During this process, operations with
147+
such detensored operands are also converted to an equivalent form that works
148+
on primitives.
149+
150+
The detensoring process is driven by linalg-on-tensor ops. In particular, a
151+
linalg-on-tensor op is checked to see whether *all* its operands can be
152+
detensored. If so, those operands are converted to their primitive
153+
counterparts and the linalg op is replaced by an equivalent op that takes
154+
those new primitive values as operands. Therefore, the detensoring process
155+
can be divided into 2 main logical phases:
156+
157+
1. Detect/match an op that can be detensored.
158+
2. Detensor the operands of the op and replace it with a primitive
159+
equivalent.
160+
}];
161+
}
162+
139163
#endif // MLIR_DIALECT_LINALG_PASSES

mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
add_mlir_dialect_library(MLIRLinalgTransforms
22
Bufferize.cpp
33
CodegenStrategy.cpp
4+
Detensorize.cpp
45
DropUnitDims.cpp
56
ElementwiseToLinalg.cpp
67
Fusion.cpp
Lines changed: 173 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,173 @@
1+
//===- Detensorize.cpp - Linalg transformations as patterns ----------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#include "PassDetail.h"
10+
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
11+
#include "mlir/Dialect/Linalg/IR/LinalgTypes.h"
12+
#include "mlir/Dialect/Linalg/Passes.h"
13+
#include "mlir/Dialect/StandardOps/Transforms/FuncConversions.h"
14+
#include "mlir/Dialect/Tensor/IR/Tensor.h"
15+
#include "mlir/IR/OpDefinition.h"
16+
#include "mlir/Transforms/DialectConversion.h"
17+
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
18+
#include <iterator>
19+
#include <memory>
20+
21+
using namespace mlir;
22+
using namespace mlir::linalg;
23+
24+
namespace {
25+
/// Defines the criteria a TensorType must follow in order to be considered
26+
/// "detensorable".
27+
///
28+
/// NOTE: For now, only 0-D are supported.
29+
///
30+
/// Returns true if tensorType can be detensored.
31+
bool canBeDetensored(TensorType tensorType) {
32+
return tensorType.hasRank() && tensorType.getRank() == 0;
33+
}
34+
35+
/// A conversion patttern for detensoring `linalg.generic` ops.
36+
class DetensorizeGenericOp : public OpConversionPattern<GenericOp> {
37+
public:
38+
using OpConversionPattern::OpConversionPattern;
39+
LogicalResult
40+
matchAndRewrite(GenericOp op, ArrayRef<Value> operands,
41+
ConversionPatternRewriter &rewriter) const override {
42+
Block *originalBlock = op->getBlock();
43+
44+
// Gather some information about the op before inling its region.
45+
Block *opEntryBlock = &*op.region().begin();
46+
YieldOp yieldOp = dyn_cast<YieldOp>(op.region().back().getTerminator());
47+
48+
// Split the op's region before the op. This way, we have a clear insertion
49+
// point in which the op can be inlined.
50+
Block *newBlock = originalBlock->splitBlock(op);
51+
rewriter.inlineRegionBefore(op.region(), newBlock);
52+
// Now that op's region is inlined, the operands of its YieldOp are mapped
53+
// to the materialized target values. Therefore, we can replace the op's
54+
// uses with those of its YielOp's operands.
55+
rewriter.replaceOp(op, yieldOp->getOperands());
56+
57+
// No need for these intermediate blocks, merge them into 1.
58+
rewriter.mergeBlocks(opEntryBlock, originalBlock, operands);
59+
rewriter.mergeBlocks(newBlock, originalBlock, {});
60+
61+
rewriter.eraseOp(&*Block::iterator(yieldOp));
62+
63+
return success();
64+
}
65+
};
66+
67+
class DetensorizeTypeConverter : public TypeConverter {
68+
public:
69+
DetensorizeTypeConverter() {
70+
addConversion([](Type type) { return type; });
71+
72+
// A TensorType that can be detensored, is converted to the underlying
73+
// element type.
74+
addConversion([](TensorType tensorType) -> Type {
75+
if (canBeDetensored(tensorType))
76+
return tensorType.getElementType();
77+
78+
return tensorType;
79+
});
80+
81+
// A tensor value is detensoried by extracting its element(s).
82+
addTargetMaterialization([](OpBuilder &builder, Type type,
83+
ValueRange inputs, Location loc) -> Value {
84+
return builder.create<tensor::ExtractOp>(loc, inputs[0], ValueRange{});
85+
});
86+
87+
// A detensored value is converted back by creating a new tensor from its
88+
// element(s).
89+
addSourceMaterialization([](OpBuilder &builder, Type type,
90+
ValueRange inputs, Location loc) -> Value {
91+
auto createNewTensorOp = builder.create<tensor::FromElementsOp>(
92+
loc, inputs[0].getType(), inputs[0]);
93+
94+
// FromElementsOp results in a tensor<1xdtype>, we need to reshape that to
95+
// a tensor<dtype> instead.
96+
return builder.create<linalg::TensorReshapeOp>(
97+
loc, type, createNewTensorOp, ArrayRef<ReassociationExprs>{});
98+
});
99+
}
100+
};
101+
102+
/// Canonicalizes the pattern of the form
103+
///
104+
/// %tensor = tensor.from_elements(%element) : (i32) -> tensor<1xi32>
105+
/// %reshaped_tensor = linalg.tensor_reshape %tensor [] : tensor<1xi32> into
106+
/// tensor<i32>
107+
/// %extracted_element = tensor.extract %reshaped_tensor[] : tensor<i32>
108+
///
109+
/// to just %element.
110+
struct ExtractFromReshapeFromElements
111+
: public OpRewritePattern<tensor::ExtractOp> {
112+
using OpRewritePattern<tensor::ExtractOp>::OpRewritePattern;
113+
114+
LogicalResult matchAndRewrite(tensor::ExtractOp extract,
115+
PatternRewriter &rewriter) const final {
116+
if (extract.indices().size() != 0)
117+
return failure();
118+
119+
auto tensorReshape = extract.tensor().getDefiningOp<TensorReshapeOp>();
120+
if (tensorReshape == nullptr)
121+
return failure();
122+
123+
auto tensorFromElements =
124+
tensorReshape.getOperand()
125+
.getDefiningOp<mlir::tensor::FromElementsOp>();
126+
if (tensorFromElements == nullptr)
127+
return failure();
128+
129+
rewriter.replaceOp(extract, tensorFromElements.getOperand(0));
130+
return success();
131+
}
132+
};
133+
134+
/// @see LinalgDetensorize in Linalg/Passes.td for more details.
135+
struct LinalgDetensorize : public LinalgDetensorizeBase<LinalgDetensorize> {
136+
void runOnFunction() override {
137+
auto *context = &getContext();
138+
DetensorizeTypeConverter typeConverter;
139+
OwningRewritePatternList patterns;
140+
ConversionTarget target(*context);
141+
142+
target.markUnknownOpDynamicallyLegal([](Operation *op) { return true; });
143+
target.addLegalDialect<linalg::LinalgDialect>();
144+
target.addDynamicallyLegalOp<GenericOp>([&](GenericOp op) {
145+
// If any of the operands or results cannot be detensored, the op is
146+
// considered legal and won't be detensored.
147+
return llvm::any_of(
148+
op.getShapedOperandTypes(), [](ShapedType shapedType) {
149+
assert(shapedType.isa<TensorType>());
150+
return !canBeDetensored(shapedType.cast<TensorType>());
151+
});
152+
});
153+
154+
patterns.insert<DetensorizeGenericOp>(typeConverter, context);
155+
156+
if (failed(
157+
applyPartialConversion(getFunction(), target, std::move(patterns))))
158+
signalPassFailure();
159+
160+
OwningRewritePatternList canonPatterns;
161+
canonPatterns.insert<ExtractFromReshapeFromElements>(context);
162+
if (failed(applyPatternsAndFoldGreedily(getFunction(),
163+
std::move(canonPatterns))))
164+
signalPassFailure();
165+
166+
// TODO Properly handle control flow within function boundaries.
167+
}
168+
};
169+
} // namespace
170+
171+
std::unique_ptr<Pass> mlir::createLinalgDetensorizePass() {
172+
return std::make_unique<LinalgDetensorize>();
173+
}
Lines changed: 107 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,107 @@
1+
// RUN: mlir-opt %s -allow-unregistered-dialect -linalg-detensorize | FileCheck %s
2+
3+
#map = affine_map<() -> ()>
4+
5+
func @detensor_simple(%arg1: tensor<f32>, %arg2: tensor<f32>) -> tensor<f32> attributes {iree.module.export} {
6+
%0 = linalg.init_tensor [] : tensor<f32>
7+
%1 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = []}
8+
ins(%arg1, %arg2 : tensor<f32>, tensor<f32>)
9+
outs(%0 : tensor<f32>) {
10+
^bb0(%arg3: f32, %arg4: f32, %arg5: f32): // no predecessors
11+
%2 = addf %arg3, %arg4 : f32
12+
linalg.yield %2 : f32
13+
} -> tensor<f32>
14+
return %1: tensor<f32>
15+
}
16+
// CHECK-LABEL: func @detensor_simple
17+
// CHECK-SAME: (%[[arg1:.*]]: tensor<f32>, %[[arg2:.*]]: tensor<f32>)
18+
// CHECK-DAG: %[[arg1_val:.*]] = tensor.extract %[[arg1]]
19+
// CHECK-DAG: %[[arg2_val:.*]] = tensor.extract %[[arg2]]
20+
// CHECK: %[[detensored_res:.*]] = addf %[[arg1_val]], %[[arg2_val]]
21+
// CHECK: %[[new_tensor_res:.*]] = tensor.from_elements %[[detensored_res]]
22+
// CHECK: %[[reshaped_tensor_res:.*]] = linalg.tensor_reshape %[[new_tensor_res]]
23+
// CHECK: return %[[reshaped_tensor_res]]
24+
25+
func @detensor_op_sequence(%arg1: tensor<f32>, %arg2: tensor<f32>) -> tensor<f32> attributes {iree.module.export} {
26+
%0 = linalg.init_tensor [] : tensor<f32>
27+
%1 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = []}
28+
ins(%arg1, %arg2 : tensor<f32>, tensor<f32>)
29+
outs(%0 : tensor<f32>) {
30+
^bb0(%arg3: f32, %arg4: f32, %arg5: f32): // no predecessors
31+
%2 = addf %arg3, %arg4 : f32
32+
linalg.yield %2 : f32
33+
} -> tensor<f32>
34+
35+
%3 = linalg.init_tensor [] : tensor<f32>
36+
%4 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = []}
37+
ins(%arg1, %1 : tensor<f32>, tensor<f32>)
38+
outs(%3 : tensor<f32>) {
39+
^bb0(%arg3: f32, %arg4: f32, %arg5: f32): // no predecessors
40+
%5 = mulf %arg3, %arg4 : f32
41+
linalg.yield %5 : f32
42+
} -> tensor<f32>
43+
44+
%6 = linalg.init_tensor [] : tensor<f32>
45+
%7 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = []}
46+
ins(%1, %4 : tensor<f32>, tensor<f32>)
47+
outs(%6 : tensor<f32>) {
48+
^bb0(%arg3: f32, %arg4: f32, %arg5: f32): // no predecessors
49+
%5 = divf %arg3, %arg4 : f32
50+
linalg.yield %5 : f32
51+
} -> tensor<f32>
52+
53+
return %7: tensor<f32>
54+
}
55+
// CHECK-LABEL: func @detensor_op_sequence
56+
// CHECK-SAME: (%[[arg1:.*]]: tensor<f32>, %[[arg2:.*]]: tensor<f32>)
57+
// CHECK-DAG: %[[arg1_val:.*]] = tensor.extract %[[arg1]]
58+
// CHECK-DAG: %[[arg2_val:.*]] = tensor.extract %[[arg2]]
59+
// CHECK: %[[detensored_res:.*]] = addf %[[arg1_val]], %[[arg2_val]]
60+
// CHECK-DAG: %[[arg1_val2:.*]] = tensor.extract %[[arg1]]
61+
// CHECK: %[[detensored_res2:.*]] = mulf %[[arg1_val2]], %[[detensored_res]]
62+
// CHECK: %[[detensored_res3:.*]] = divf %[[detensored_res]], %[[detensored_res2]]
63+
// CHECK: %[[new_tensor_res:.*]] = tensor.from_elements %[[detensored_res3]]
64+
// CHECK: %[[reshaped_tensor_res:.*]] = linalg.tensor_reshape %[[new_tensor_res]]
65+
// CHECK: return %[[reshaped_tensor_res]]
66+
67+
func @detensor_multiple_ops(%arg1: tensor<f32>, %arg2: tensor<f32>) -> tensor<f32> attributes {iree.module.export} {
68+
%0 = linalg.init_tensor [] : tensor<f32>
69+
%1 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = []}
70+
ins(%arg1, %arg2 : tensor<f32>, tensor<f32>)
71+
outs(%0 : tensor<f32>) {
72+
^bb0(%arg3: f32, %arg4: f32, %arg5: f32): // no predecessors
73+
%2 = addf %arg3, %arg4 : f32
74+
%3 = mulf %2, %arg4 : f32
75+
linalg.yield %3 : f32
76+
} -> tensor<f32>
77+
return %1: tensor<f32>
78+
}
79+
// CHECK-LABEL: func @detensor_multiple_ops
80+
// CHECK-SAME: (%[[arg1:.*]]: tensor<f32>, %[[arg2:.*]]: tensor<f32>)
81+
// CHECK-DAG: %[[arg1_val:.*]] = tensor.extract %[[arg1]]
82+
// CHECK-DAG: %[[arg2_val:.*]] = tensor.extract %[[arg2]]
83+
// CHECK: %[[detensored_res:.*]] = addf %[[arg1_val]], %[[arg2_val]]
84+
// CHECK: %[[detensored_res2:.*]] = mulf %[[detensored_res]], %[[arg2_val]]
85+
// CHECK: %[[new_tensor_res:.*]] = tensor.from_elements %[[detensored_res2]]
86+
// CHECK: %[[reshaped_tensor_res:.*]] = linalg.tensor_reshape %[[new_tensor_res]]
87+
// CHECK: return %[[reshaped_tensor_res]]
88+
89+
func @detensor_foreign_op(%arg1: tensor<f32>, %arg2: tensor<f32>) -> tensor<f32> attributes {iree.module.export} {
90+
%0 = linalg.init_tensor [] : tensor<f32>
91+
%1 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = []}
92+
ins(%arg1, %arg2 : tensor<f32>, tensor<f32>)
93+
outs(%0 : tensor<f32>) {
94+
^bb0(%arg3: f32, %arg4: f32, %arg5: f32): // no predecessors
95+
%2 = "foreign.do_something"(%arg3, %arg4) {} : (f32, f32) -> f32
96+
linalg.yield %2 : f32
97+
} -> tensor<f32>
98+
return %1: tensor<f32>
99+
}
100+
// CHECK-LABEL: func @detensor_foreign_op
101+
// CHECK-SAME: (%[[arg1:.*]]: tensor<f32>, %[[arg2:.*]]: tensor<f32>)
102+
// CHECK-DAG: %[[arg1_val:.*]] = tensor.extract %[[arg1]]
103+
// CHECK-DAG: %[[arg2_val:.*]] = tensor.extract %[[arg2]]
104+
// CHECK: %[[detensored_res:.*]] = "foreign.do_something"(%[[arg1_val]], %[[arg2_val]])
105+
// CHECK: %[[new_tensor_res:.*]] = tensor.from_elements %[[detensored_res]]
106+
// CHECK: %[[reshaped_tensor_res:.*]] = linalg.tensor_reshape %[[new_tensor_res]]
107+
// CHECK: return %[[reshaped_tensor_res]]

0 commit comments

Comments
 (0)