33#include " mlir/Analysis/SliceAnalysis.h"
44#include " mlir/Dialect/SCF/IR/SCF.h"
55#include " mlir/IR/BuiltinAttributes.h"
6+ #include " mlir/IR/Dominance.h"
67#include " mlir/IR/IRMapping.h"
78#include " mlir/IR/Matchers.h"
89#include " mlir/IR/PatternMatch.h"
@@ -116,17 +117,15 @@ class LayoutPropagation {
116117class LayoutRematerialization {
117118public:
118119 LayoutRematerialization (FuncOp F) : funcOp(F) {}
120+
119121 // Map the original value to the remat'ed one.
120122 void addRematValue (Value old, Attribute encoding, Value newV);
121- bool hasRematValue (Value value, Attribute encoding) {
122- return rematMapping.contains ({value, encoding});
123- }
124- // Return the remat'ed value in the given encoding.
125- Value getRematValue (Value value, Attribute encoding) {
126- auto it = rematMapping.find ({value, encoding});
127- assert (it != rematMapping.end ());
128- return it->second ;
123+ // Get the remat'ed value in the given encoding, if one already exists and
124+ // is different than the layout conversion root.
125+ Value getRematValue (Value value, Attribute encoding) const {
126+ return rematMapping.lookup ({value, encoding});
129127 }
128+
130129 void cleanup ();
131130 void backwardRematerialization ();
132131 void backwardRematerialization (ConvertLayoutOp convertOp);
@@ -137,6 +136,11 @@ class LayoutRematerialization {
137136 void rewriteSlice (SetVector<Value> &slice, DenseMap<Value, Attribute> &layout,
138137 ConvertLayoutOp convertOp);
139138
139+ LogicalResult getRematerializableSlice (
140+ OpOperand &root, Attribute rootEncoding, SetVector<Value> &slice,
141+ DenseMap<Value, Attribute> &layout,
142+ std::function<bool (Operation *)> stopPropagation = nullptr);
143+
140144private:
141145 void updateRematMapping (SmallVector<std::tuple<Value, Value>> &values);
142146 // Existing tuples of (value, layout) that need to be updated when recreating
@@ -148,6 +152,7 @@ class LayoutRematerialization {
148152 // DenseMap<std::pair<Operation*, Attribute>, Operation*>
149153 SetVector<Operation *> opToDelete;
150154 FuncOp funcOp;
155+ DominanceInfo domInfo;
151156};
152157
153158void LayoutRematerialization::addRematValue (Value old, Attribute encoding,
@@ -778,8 +783,8 @@ void LayoutRematerialization::rewriteSlice(SetVector<Value> &slice,
778783 auto layoutIt = layout.find (v);
779784 assert (layoutIt != layout.end ());
780785 // If we already have a remat value for this value, use it.
781- if (hasRematValue (v, layoutIt->second )) {
782- mapping.map (v, getRematValue (v, layoutIt-> second ) );
786+ if (Value remat = getRematValue (v, layoutIt->second )) {
787+ mapping.map (v, remat );
783788 valuesWithExistingRemat.insert (v);
784789 continue ;
785790 }
@@ -940,12 +945,36 @@ void LayoutRematerialization::rewriteSlice(SetVector<Value> &slice,
940945 rewriteSlice (slice, layout, convertOp, mapping);
941946}
942947
943- LogicalResult getRematerializableSlice (
944- Value root, Attribute rootEncoding, SetVector<Value> &slice,
948+ LogicalResult LayoutRematerialization:: getRematerializableSlice (
949+ OpOperand & root, Attribute rootEncoding, SetVector<Value> &slice,
945950 DenseMap<Value, Attribute> &layout,
946- std::function<bool (Operation *)> stopPropagation = nullptr) {
947- LogicalResult result = getConvertBackwardSlice (root, slice, rootEncoding,
948- layout, stopPropagation);
951+ std::function<bool (Operation *)> stopPropagation) {
952+ // Allow re-using existing conversions for a value. Check dominance of any
953+ // re-usable materializations against the root value. This is sufficient
954+ // because the conversions are processed in post-order.
955+ auto getExistingConversion = [&](OpOperand &value, Attribute encoding) {
956+ Value remat = getRematValue (value.get (), encoding);
957+ if (!remat)
958+ return Value ();
959+ // `value` can be replaced with an existing rematerialization if that
960+ // rematerialization dominates the current use of `value`.
961+ Operation *user = value.getOwner ();
962+ if (domInfo.properlyDominates (remat, user)) {
963+ return remat;
964+ }
965+ // Alternatively, if the current use can be sunk below the existing
966+ // rematerialization, then it is okay to use as well. E.g. the current use
967+ // is a conversion that will be folded away when its result is
968+ // rematerialized.
969+ if (isa<ConvertLayoutOp>(user) && remat.getDefiningOp () &&
970+ domInfo.properlyDominates (user, remat.getDefiningOp ())) {
971+ return remat;
972+ }
973+ return Value ();
974+ };
975+ LogicalResult result =
976+ getConvertBackwardSlice (root, slice, rootEncoding, layout,
977+ stopPropagation, getExistingConversion);
949978 if (result.failed () || slice.empty ())
950979 return failure ();
951980
@@ -966,6 +995,12 @@ void LayoutRematerialization::backwardRematerialization() {
966995 [&](ConvertLayoutOp convertOp) { convertOps.push_back (convertOp); });
967996 for (ConvertLayoutOp convertOp : convertOps) {
968997 backwardRematerialization (convertOp);
998+ if (!opToDelete.contains (convertOp)) {
999+ // If the conversion didn't get removed, consider it for re-use in future
1000+ // backward slices.
1001+ addRematValue (convertOp.getSrc (), convertOp.getType ().getEncoding (),
1002+ convertOp.getResult ());
1003+ }
9691004 }
9701005}
9711006
@@ -976,6 +1011,12 @@ void LayoutRematerialization::hoistConvertOnTopOfExtOrBroadcast() {
9761011 [&](ConvertLayoutOp convertOp) { convertOps.push_back (convertOp); });
9771012 for (ConvertLayoutOp convertOp : convertOps) {
9781013 hoistConvertOnTopOfExtOrBroadcast (convertOp);
1014+ if (!opToDelete.contains (convertOp)) {
1015+ // If the conversion didn't get removed, consider it for re-use in future
1016+ // backward slices.
1017+ addRematValue (convertOp.getSrc (), convertOp.getType ().getEncoding (),
1018+ convertOp.getResult ());
1019+ }
9791020 }
9801021}
9811022
@@ -988,14 +1029,14 @@ void LayoutRematerialization::backwardRematerialization(
9881029 // careful with the heuristics for both correctness and perf
9891030 if (isa<DotOperandEncodingAttr, LinearEncodingAttr>(targetType.getEncoding ()))
9901031 return ;
991- Value oldV = convertOp-> getOperand ( 0 );
1032+ Value oldV = convertOp. getSrc ( );
9921033 LDBG (" check backward remat with source " << oldV << " encoding "
9931034 << targetType.getEncoding ());
9941035 // Check to see if there are existing remat'ed values for the pair of oldValue
995- // and encoding.
996- if (hasRematValue (oldV, targetType.getEncoding ())) {
1036+ // and encoding. Make sure it dominates the current conversion.
1037+ Value newV = getRematValue (oldV, targetType.getEncoding ());
1038+ if (newV && domInfo.properlyDominates (newV, convertOp)) {
9971039 // Replace it with the remat'ed value.
998- Value newV = getRematValue (oldV, targetType.getEncoding ());
9991040 convertOp.replaceAllUsesWith (newV);
10001041 opToDelete.insert (convertOp);
10011042 LDBG (" found remat'ed value" << newV);
@@ -1007,7 +1048,7 @@ void LayoutRematerialization::backwardRematerialization(
10071048 SetVector<Value> slice;
10081049 DenseMap<Value, Attribute> layout;
10091050 LogicalResult result = getRematerializableSlice (
1010- convertOp.getSrc (), targetType.getEncoding (), slice, layout);
1051+ convertOp.getSrcMutable (), targetType.getEncoding (), slice, layout);
10111052 if (result.failed ()) {
10121053 LDBG (" getRematerializableSlice failed" );
10131054 return ;
@@ -1050,9 +1091,9 @@ void LayoutRematerialization::hoistConvertOnTopOfExtOrBroadcast(
10501091 // 1. Take a backward slice of all the tensor dependencies.
10511092 SetVector<Value> slice;
10521093 DenseMap<Value, Attribute> layout;
1053- LogicalResult result =
1054- getRematerializableSlice ( convertOp.getSrc (), targetType.getEncoding (),
1055- slice, layout, isExtOrBroadcastOp);
1094+ LogicalResult result = getRematerializableSlice (
1095+ convertOp.getSrcMutable (), targetType.getEncoding (), slice, layout ,
1096+ isExtOrBroadcastOp);
10561097 if (result.failed ())
10571098 return ;
10581099
@@ -1070,7 +1111,7 @@ void LayoutRematerialization::hoistConvertOnTopOfExtOrBroadcast(
10701111 if (!srcEncoding)
10711112 return ;
10721113 LogicalResult result = getRematerializableSlice (
1073- op->getOperand (0 ), srcEncoding, tempSlice, tempLayout);
1114+ op->getOpOperand (0 ), srcEncoding, tempSlice, tempLayout);
10741115 // If we can rematerialize the rest of the ext slice we can ignore this
10751116 // ext as it won't need a convert.
10761117 if (result.succeeded ()) {
0 commit comments