@@ -158,19 +158,22 @@ class LayoutRematerialization {
158158 getConvertBackwardSlice (OpOperand &root, Attribute rootEncoding,
159159 SetVector<Value> &slice,
160160 DenseMap<Value, Attribute> &layout,
161- std::function<bool (Operation *)> stopPropagation);
161+ std::function<bool (Operation *)> stopPropagation,
162+ bool includeForOp = false );
162163
163164 LogicalResult getRematerializableSlice (
164165 OpOperand &root, Attribute rootEncoding, SetVector<Value> &slice,
165166 DenseMap<Value, Attribute> &layout,
166- std::function<bool (Operation *)> stopPropagation = nullptr);
167+ std::function<bool (Operation *)> stopPropagation = nullptr,
168+ bool includeForOp = false);
167169
168170private:
169171 void updateRematMapping (SmallVector<std::tuple<Value, Value>> &values);
170172 // Existing tuples of (value, layout) that need to be updated when recreating
171173 // scf ops. This prevents keeping track of Values that have been deleted when
172- // rewriting slices.
173- DenseMap<Value, Attribute> mappedValues;
174+ // rewriting slices. The Value may be mapped to different attributes in remove
175+ // layout.
176+ DenseMap<Value, SmallVector<Attribute>> mappedValues;
174177 // map of the values remat based on encoding.
175178 DenseMap<std::pair<Value, Attribute>, Value> rematMapping;
176179 // DenseMap<std::pair<Operation*, Attribute>, Operation*>
@@ -184,7 +187,11 @@ void LayoutRematerialization::addRematValue(Value old, Attribute encoding,
184187 Value newV) {
185188 LDBG (" addRematValue " << old << " encoding " << encoding << " " << newV);
186189 rematMapping[{old, encoding}] = newV;
187- mappedValues[old] = encoding;
190+ if (mappedValues.contains (old)) {
191+ mappedValues[old].push_back (encoding);
192+ } else {
193+ mappedValues[old] = {encoding};
194+ }
188195}
189196
190197// Remove unneeded values now that we are done with the rematMapping.
@@ -989,22 +996,28 @@ void LayoutRematerialization::updateRematMapping(
989996 for (auto [old, newV] : values) {
990997 auto it = mappedValues.find (old);
991998 if (it != mappedValues.end ()) {
992- Attribute encoding = it->second ;
993- auto rematIt = rematMapping.find ({old, it->second });
994- assert (rematIt != rematMapping.end ());
995- Value replacedValue = rematIt->second ;
996- rematMapping.erase (rematIt);
997- mappedValues.erase (it);
998- // Loop through the replacement value to find the new version of remat
999- // value. This should be okay as the number of values should be small.
1000- for (auto [before, after] : values) {
1001- if (before == replacedValue) {
1002- replacedValue = after;
1003- break ;
999+ SmallVector<Attribute> encodings = it->second ;
1000+ for (auto encoding : encodings) {
1001+ auto rematIt = rematMapping.find ({old, encoding});
1002+ assert (rematIt != rematMapping.end ());
1003+ Value replacedValue = rematIt->second ;
1004+ rematMapping.erase (rematIt);
1005+ // Loop through the replacement values to find the new version of remat
1006+ // value. This should be okay as the number of values should be small.
1007+ for (auto [before, after] : values) {
1008+ if (before == replacedValue) {
1009+ replacedValue = after;
1010+ break ;
1011+ }
10041012 }
1013+ rematMapping[{newV, encoding}] = replacedValue;
1014+ }
1015+ mappedValues.erase (it);
1016+ if (mappedValues.contains (newV)) {
1017+ mappedValues[newV].append (encodings);
1018+ } else {
1019+ mappedValues[newV] = std::move (encodings);
10051020 }
1006- rematMapping[{newV, encoding}] = replacedValue;
1007- mappedValues[newV] = encoding;
10081021 }
10091022 }
10101023}
@@ -1079,6 +1092,7 @@ void LayoutRematerialization::rewriteSlice(SetVector<Value> &slice,
10791092 deadOps.push_back (forOp.getOperation ());
10801093 Block &loopBody = *newForOp.getBody ();
10811094 for (auto m : argMapping) {
1095+ mapping.map (newForOp.getResult (m.first ), newForOp.getResult (m.second ));
10821096 mapping.map (forOp.getResult (m.first ), newForOp.getResult (m.second ));
10831097 int numIndVars = newForOp.getNumInductionVars ();
10841098 mapping.map (loopBody.getArgument (m.first + numIndVars),
@@ -1189,8 +1203,12 @@ void LayoutRematerialization::rewriteSlice(SetVector<Value> &slice,
11891203 builder.replaceAllUsesWith (std::get<0 >(kv), std::get<1 >(kv));
11901204 }
11911205
1192- for (Operation *op : deadOps)
1193- opToDelete.insert (op);
1206+ for (Operation *op : deadOps) {
1207+ if (!isa<scf::ForOp>(op))
1208+ opToDelete.insert (op);
1209+ else
1210+ op->erase ();
1211+ }
11941212}
11951213
11961214void LayoutRematerialization::rewriteSlice (SetVector<Value> &slice,
@@ -1203,7 +1221,7 @@ void LayoutRematerialization::rewriteSlice(SetVector<Value> &slice,
12031221LogicalResult LayoutRematerialization::getConvertBackwardSlice (
12041222 OpOperand &root, Attribute rootEncoding, SetVector<Value> &slice,
12051223 DenseMap<Value, Attribute> &layout,
1206- std::function<bool (Operation *)> stopPropagation) {
1224+ std::function<bool (Operation *)> stopPropagation, bool includeForOp ) {
12071225 // Allow re-using existing conversions for a value. Check dominance of any
12081226 // reusable materializations against the root value. This is sufficient
12091227 // because the conversions are processed in post-order.
@@ -1232,15 +1250,16 @@ LogicalResult LayoutRematerialization::getConvertBackwardSlice(
12321250 };
12331251
12341252 return ttgi::getConvertBackwardSlice (root, slice, rootEncoding, layout,
1235- stopPropagation, getExistingConversion);
1253+ stopPropagation, getExistingConversion,
1254+ includeForOp);
12361255}
12371256
12381257LogicalResult LayoutRematerialization::getRematerializableSlice (
12391258 OpOperand &root, Attribute rootEncoding, SetVector<Value> &slice,
12401259 DenseMap<Value, Attribute> &layout,
1241- std::function<bool (Operation *)> stopPropagation) {
1242- LogicalResult result = getConvertBackwardSlice (root, rootEncoding, slice,
1243- layout, stopPropagation);
1260+ std::function<bool (Operation *)> stopPropagation, bool includeForOp ) {
1261+ LogicalResult result = getConvertBackwardSlice (
1262+ root, rootEncoding, slice, layout, stopPropagation, includeForOp );
12441263 if (result.failed () || slice.empty ())
12451264 return failure ();
12461265
@@ -1434,8 +1453,9 @@ void LayoutRematerialization::backwardRematerialization(
14341453 // rematerialized.
14351454 SetVector<Value> slice;
14361455 DenseMap<Value, Attribute> layout;
1437- LogicalResult result = getRematerializableSlice (
1438- convertOp.getSrcMutable (), targetType.getEncoding (), slice, layout);
1456+ LogicalResult result = getRematerializableSlice (convertOp.getSrcMutable (),
1457+ targetType.getEncoding (),
1458+ slice, layout, nullptr , true );
14391459 if (result.failed ()) {
14401460 LDBG (" getRematerializableSlice failed" );
14411461 return ;
0 commit comments