//===- Detensorize.cpp - Linalg detensorization pass and patterns --------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "PassDetail.h"
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
#include "mlir/Dialect/Linalg/IR/LinalgTypes.h"
#include "mlir/Dialect/Linalg/Passes.h"
#include "mlir/Dialect/StandardOps/Transforms/FuncConversions.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include <iterator>
#include <memory>

using namespace mlir;
using namespace mlir::linalg;

namespace {
/// Defines the criteria a TensorType must satisfy in order to be considered
/// "detensorable".
///
/// NOTE: For now, only 0-D tensors are supported.
///
/// Returns true if tensorType can be detensored.
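///
/// For example, `tensor<i32>` can be detensored to `i32`, while `tensor<4xi32>`
/// and unranked tensors cannot.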
bool canBeDetensored(TensorType tensorType) {
  return tensorType.hasRank() && tensorType.getRank() == 0;
}

/// A conversion pattern for detensoring `linalg.generic` ops.
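///
/// As a rough sketch (operand names and attributes below are illustrative, not
/// taken from any particular test), detensoring rewrites a 0-D element-wise op
/// such as
///
///   %res = linalg.generic {...}
///       ins(%a, %b : tensor<i32>, tensor<i32>)
///       outs(%c : tensor<i32>) {
///   ^bb0(%a_elem: i32, %b_elem: i32, %c_elem: i32):
///     %sum = addi %a_elem, %b_elem : i32
///     linalg.yield %sum : i32
///   } -> tensor<i32>
///
/// into the equivalent scalar computation in the enclosing block, with the type
/// converter's materializations bridging any remaining tensor uses.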
class DetensorizeGenericOp : public OpConversionPattern<GenericOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(GenericOp op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    Block *originalBlock = op->getBlock();

    // Gather some information about the op before inlining its region.
    Block *opEntryBlock = &*op.region().begin();
    YieldOp yieldOp = dyn_cast<YieldOp>(op.region().back().getTerminator());

    // Split the enclosing block at the op. This gives a clear insertion point
    // before which the op's region can be inlined.
    Block *newBlock = originalBlock->splitBlock(op);
    rewriter.inlineRegionBefore(op.region(), newBlock);
    // Now that the op's region is inlined, the operands of its YieldOp are
    // mapped to the materialized target values. Therefore, we can replace the
    // op's results with its YieldOp's operands.
    rewriter.replaceOp(op, yieldOp->getOperands());

    // No need for the intermediate blocks; merge them into one.
    rewriter.mergeBlocks(opEntryBlock, originalBlock, operands);
    rewriter.mergeBlocks(newBlock, originalBlock, {});

    rewriter.eraseOp(&*Block::iterator(yieldOp));

    return success();
  }
};

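/// Converts detensorable tensor types to their element types and provides the
/// materializations needed to move values across the tensor/scalar boundary.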
class DetensorizeTypeConverter : public TypeConverter {
public:
  DetensorizeTypeConverter() {
    addConversion([](Type type) { return type; });

    // A TensorType that can be detensored is converted to its underlying
    // element type.
    addConversion([](TensorType tensorType) -> Type {
      if (canBeDetensored(tensorType))
        return tensorType.getElementType();

      return tensorType;
    });

    // A tensor value is detensored by extracting its element(s).
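    // E.g. (a sketch, assuming a 0-D input %t):
    //   %e = tensor.extract %t[] : tensor<i32>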
    addTargetMaterialization([](OpBuilder &builder, Type type,
                                ValueRange inputs, Location loc) -> Value {
      return builder.create<tensor::ExtractOp>(loc, inputs[0], ValueRange{});
    });

    // A detensored value is converted back by creating a new tensor from its
    // element(s).
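    // E.g. (a sketch, assuming a detensored i32 value %e):
    //   %t0 = tensor.from_elements(%e) : (i32) -> tensor<1xi32>
    //   %t = linalg.tensor_reshape %t0 [] : tensor<1xi32> into tensor<i32>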
    addSourceMaterialization([](OpBuilder &builder, Type type,
                                ValueRange inputs, Location loc) -> Value {
      auto createNewTensorOp = builder.create<tensor::FromElementsOp>(
          loc, inputs[0].getType(), inputs[0]);

      // FromElementsOp results in a tensor<1xdtype>; reshape that into a
      // tensor<dtype> instead.
      return builder.create<linalg::TensorReshapeOp>(
          loc, type, createNewTensorOp, ArrayRef<ReassociationExprs>{});
    });
  }
};

/// Canonicalizes the pattern of the form
///
/// %tensor = tensor.from_elements(%element) : (i32) -> tensor<1xi32>
/// %reshaped_tensor = linalg.tensor_reshape %tensor [] : tensor<1xi32> into
/// tensor<i32>
/// %extracted_element = tensor.extract %reshaped_tensor[] : tensor<i32>
///
/// to just %element.
struct ExtractFromReshapeFromElements
    : public OpRewritePattern<tensor::ExtractOp> {
  using OpRewritePattern<tensor::ExtractOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::ExtractOp extract,
                                PatternRewriter &rewriter) const final {
    if (extract.indices().size() != 0)
      return failure();

    auto tensorReshape = extract.tensor().getDefiningOp<TensorReshapeOp>();
    if (tensorReshape == nullptr)
      return failure();

    auto tensorFromElements =
        tensorReshape.getOperand()
            .getDefiningOp<mlir::tensor::FromElementsOp>();
    if (tensorFromElements == nullptr)
      return failure();

    rewriter.replaceOp(extract, tensorFromElements.getOperand(0));
    return success();
  }
};

/// @see LinalgDetensorize in Linalg/Passes.td for more details.
struct LinalgDetensorize : public LinalgDetensorizeBase<LinalgDetensorize> {
  void runOnFunction() override {
    auto *context = &getContext();
    DetensorizeTypeConverter typeConverter;
    OwningRewritePatternList patterns;
    ConversionTarget target(*context);

    target.markUnknownOpDynamicallyLegal([](Operation *op) { return true; });
    target.addLegalDialect<linalg::LinalgDialect>();
    target.addDynamicallyLegalOp<GenericOp>([&](GenericOp op) {
      // If any of the operands or results cannot be detensored, the op is
      // considered legal and won't be detensored.
      return llvm::any_of(
          op.getShapedOperandTypes(), [](ShapedType shapedType) {
            assert(shapedType.isa<TensorType>());
            return !canBeDetensored(shapedType.cast<TensorType>());
          });
    });

    patterns.insert<DetensorizeGenericOp>(typeConverter, context);

    if (failed(
            applyPartialConversion(getFunction(), target, std::move(patterns))))
      signalPassFailure();

    OwningRewritePatternList canonPatterns;
    canonPatterns.insert<ExtractFromReshapeFromElements>(context);
    if (failed(applyPatternsAndFoldGreedily(getFunction(),
                                            std::move(canonPatterns))))
      signalPassFailure();

    // TODO: Properly handle control flow within function boundaries.
  }
};
} // namespace

std::unique_ptr<Pass> mlir::createLinalgDetensorizePass() {
  return std::make_unique<LinalgDetensorize>();
}