Skip to content

Commit 0e06e61

Browse files
committed
Temporarily enhance the remove-layout implementation to reduce duplicated values with different layouts in scf.for.
Signed-off-by: Lu,Chengjun <chengjun.lu@intel.com>
1 parent 37866dc commit 0e06e61

File tree

3 files changed

+75
-34
lines changed

3 files changed

+75
-34
lines changed

third_party/intel/include/Dialect/TritonIntelGPU/Transforms/Utility.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,8 @@ LogicalResult getConvertBackwardSlice(
5050
DenseMap<Value, Attribute> &layout,
5151
std::function<bool(Operation *)> stopPropagation = nullptr,
5252
std::function<Value(OpOperand &, Attribute)> getExistingConversion =
53-
nullptr);
53+
nullptr,
54+
bool includeForOp = false);
5455

5556
LLVM::LLVMFuncOp lookupOrCreateSPIRVFn(Operation *symbolTable, StringRef name,
5657
ArrayRef<Type> paramTypes,

third_party/intel/lib/TritonIntelGPUTransforms/RemoveLayoutConversions.cpp

Lines changed: 48 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -158,19 +158,22 @@ class LayoutRematerialization {
158158
getConvertBackwardSlice(OpOperand &root, Attribute rootEncoding,
159159
SetVector<Value> &slice,
160160
DenseMap<Value, Attribute> &layout,
161-
std::function<bool(Operation *)> stopPropagation);
161+
std::function<bool(Operation *)> stopPropagation,
162+
bool includeForOp = false);
162163

163164
LogicalResult getRematerializableSlice(
164165
OpOperand &root, Attribute rootEncoding, SetVector<Value> &slice,
165166
DenseMap<Value, Attribute> &layout,
166-
std::function<bool(Operation *)> stopPropagation = nullptr);
167+
std::function<bool(Operation *)> stopPropagation = nullptr,
168+
bool includeForOp = false);
167169

168170
private:
169171
void updateRematMapping(SmallVector<std::tuple<Value, Value>> &values);
170172
// Existing tuples of (value, layout) that needs to be updated when recreating
171173
// scf ops. This prevents keeping track of Values that have been deleted when
172-
// rewriting slices.
173-
DenseMap<Value, Attribute> mappedValues;
174+
// rewriting slices. The Value may be mapped to different attributes in remove
175+
// layout.
176+
DenseMap<Value, SmallVector<Attribute>> mappedValues;
174177
// map of the values remat based on encoding.
175178
DenseMap<std::pair<Value, Attribute>, Value> rematMapping;
176179
// DenseMap<std::pair<Operation*, Attribute>, Operation*>
@@ -184,7 +187,11 @@ void LayoutRematerialization::addRematValue(Value old, Attribute encoding,
184187
Value newV) {
185188
LDBG("addRematValue " << old << " encoding " << encoding << " " << newV);
186189
rematMapping[{old, encoding}] = newV;
187-
mappedValues[old] = encoding;
190+
if (mappedValues.contains(old)) {
191+
mappedValues[old].push_back(encoding);
192+
} else {
193+
mappedValues[old] = {encoding};
194+
}
188195
}
189196

190197
// Remove unneeded values now that we are done with the rematMapping.
@@ -989,22 +996,28 @@ void LayoutRematerialization::updateRematMapping(
989996
for (auto [old, newV] : values) {
990997
auto it = mappedValues.find(old);
991998
if (it != mappedValues.end()) {
992-
Attribute encoding = it->second;
993-
auto rematIt = rematMapping.find({old, it->second});
994-
assert(rematIt != rematMapping.end());
995-
Value replacedValue = rematIt->second;
996-
rematMapping.erase(rematIt);
997-
mappedValues.erase(it);
998-
// Loop through the replacement value to find the new version of remat
999-
// value. This should be okay as the number of values should be small.
1000-
for (auto [before, after] : values) {
1001-
if (before == replacedValue) {
1002-
replacedValue = after;
1003-
break;
999+
SmallVector<Attribute> encodings = it->second;
1000+
for (auto encoding : encodings) {
1001+
auto rematIt = rematMapping.find({old, encoding});
1002+
assert(rematIt != rematMapping.end());
1003+
Value replacedValue = rematIt->second;
1004+
rematMapping.erase(rematIt);
1005+
// Loop through the replacement value to find the new version of remat
1006+
// value. This should be okay as the number of values should be small.
1007+
for (auto [before, after] : values) {
1008+
if (before == replacedValue) {
1009+
replacedValue = after;
1010+
break;
1011+
}
10041012
}
1013+
rematMapping[{newV, encoding}] = replacedValue;
1014+
}
1015+
mappedValues.erase(it);
1016+
if (mappedValues.contains(newV)) {
1017+
mappedValues[newV].append(encodings);
1018+
} else {
1019+
mappedValues[newV] = std::move(encodings);
10051020
}
1006-
rematMapping[{newV, encoding}] = replacedValue;
1007-
mappedValues[newV] = encoding;
10081021
}
10091022
}
10101023
}
@@ -1079,6 +1092,7 @@ void LayoutRematerialization::rewriteSlice(SetVector<Value> &slice,
10791092
deadOps.push_back(forOp.getOperation());
10801093
Block &loopBody = *newForOp.getBody();
10811094
for (auto m : argMapping) {
1095+
mapping.map(newForOp.getResult(m.first), newForOp.getResult(m.second));
10821096
mapping.map(forOp.getResult(m.first), newForOp.getResult(m.second));
10831097
int numIndVars = newForOp.getNumInductionVars();
10841098
mapping.map(loopBody.getArgument(m.first + numIndVars),
@@ -1189,8 +1203,12 @@ void LayoutRematerialization::rewriteSlice(SetVector<Value> &slice,
11891203
builder.replaceAllUsesWith(std::get<0>(kv), std::get<1>(kv));
11901204
}
11911205

1192-
for (Operation *op : deadOps)
1193-
opToDelete.insert(op);
1206+
for (Operation *op : deadOps) {
1207+
if (!isa<scf::ForOp>(op))
1208+
opToDelete.insert(op);
1209+
else
1210+
op->erase();
1211+
}
11941212
}
11951213

11961214
void LayoutRematerialization::rewriteSlice(SetVector<Value> &slice,
@@ -1203,7 +1221,7 @@ void LayoutRematerialization::rewriteSlice(SetVector<Value> &slice,
12031221
LogicalResult LayoutRematerialization::getConvertBackwardSlice(
12041222
OpOperand &root, Attribute rootEncoding, SetVector<Value> &slice,
12051223
DenseMap<Value, Attribute> &layout,
1206-
std::function<bool(Operation *)> stopPropagation) {
1224+
std::function<bool(Operation *)> stopPropagation, bool includeForOp) {
12071225
// Allow re-using existing conversions for a value. Check dominance of any
12081226
// reusable materializations against the root value. This is sufficient
12091227
// because the conversions are processed in post-order.
@@ -1232,15 +1250,16 @@ LogicalResult LayoutRematerialization::getConvertBackwardSlice(
12321250
};
12331251

12341252
return ttgi::getConvertBackwardSlice(root, slice, rootEncoding, layout,
1235-
stopPropagation, getExistingConversion);
1253+
stopPropagation, getExistingConversion,
1254+
includeForOp);
12361255
}
12371256

12381257
LogicalResult LayoutRematerialization::getRematerializableSlice(
12391258
OpOperand &root, Attribute rootEncoding, SetVector<Value> &slice,
12401259
DenseMap<Value, Attribute> &layout,
1241-
std::function<bool(Operation *)> stopPropagation) {
1242-
LogicalResult result = getConvertBackwardSlice(root, rootEncoding, slice,
1243-
layout, stopPropagation);
1260+
std::function<bool(Operation *)> stopPropagation, bool includeForOp) {
1261+
LogicalResult result = getConvertBackwardSlice(
1262+
root, rootEncoding, slice, layout, stopPropagation, includeForOp);
12441263
if (result.failed() || slice.empty())
12451264
return failure();
12461265

@@ -1434,8 +1453,9 @@ void LayoutRematerialization::backwardRematerialization(
14341453
// rematerialized.
14351454
SetVector<Value> slice;
14361455
DenseMap<Value, Attribute> layout;
1437-
LogicalResult result = getRematerializableSlice(
1438-
convertOp.getSrcMutable(), targetType.getEncoding(), slice, layout);
1456+
LogicalResult result = getRematerializableSlice(convertOp.getSrcMutable(),
1457+
targetType.getEncoding(),
1458+
slice, layout, nullptr, true);
14391459
if (result.failed()) {
14401460
LDBG(" getRematerializableSlice failed");
14411461
return;

third_party/intel/lib/TritonIntelGPUTransforms/Utility.cpp

Lines changed: 25 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -182,7 +182,8 @@ LogicalResult getConvertBackwardSlice(
182182
OpOperand &root, SetVector<Value> &slice, Attribute rootEncoding,
183183
DenseMap<Value, Attribute> &layout,
184184
std::function<bool(Operation *)> stopPropagation,
185-
std::function<Value(OpOperand &, Attribute)> getExistingConversion) {
185+
std::function<Value(OpOperand &, Attribute)> getExistingConversion,
186+
bool includeForOp) {
186187
DenseSet<std::pair<OpOperand *, Attribute>> seen;
187188
SmallVector<std::pair<OpOperand *, Attribute>> queue;
188189

@@ -197,6 +198,12 @@ LogicalResult getConvertBackwardSlice(
197198

198199
auto updateLayout = [&](Value value, Attribute encoding) {
199200
assert(isTensorOrTensorPointerType(value.getType()));
201+
auto tensorType = getRankedTensorType(value.getType());
202+
auto originEncoding = tensorType.getEncoding();
203+
if (originEncoding == encoding) {
204+
return success();
205+
}
206+
200207
slice.insert(value);
201208
Attribute &existing = layout[value];
202209
if (existing && existing != encoding)
@@ -211,10 +218,7 @@ LogicalResult getConvertBackwardSlice(
211218
queue.pop_back();
212219
if (!isTensorOrTensorPointerType(currentValue.getType()))
213220
continue;
214-
// Skip propagating through for op results for now.
215-
// TODO: enable this based on needs.
216-
if (currentValue.getDefiningOp<scf::ForOp>())
217-
return failure();
221+
218222
if (failed(updateLayout(currentValue, encoding)))
219223
return failure();
220224

@@ -226,6 +230,22 @@ LogicalResult getConvertBackwardSlice(
226230
currentValue = existing;
227231
}
228232

233+
if (auto forOp = currentValue.getDefiningOp<scf::ForOp>()) {
234+
if (!includeForOp)
235+
return failure();
236+
if (stopPropagation && stopPropagation(forOp))
237+
continue;
238+
unsigned argIdx = mlir::cast<OpResult>(currentValue).getResultNumber();
239+
int numIndVars = forOp.getNumInductionVars();
240+
Block &loopBody = *forOp.getBody();
241+
auto blockArg = loopBody.getArgument(argIdx + numIndVars);
242+
OpOperand *initOperand = forOp.getTiedLoopInit(blockArg);
243+
OpOperand &yieldOperand = loopBody.getTerminator()->getOpOperand(argIdx);
244+
enqueue(*initOperand, encoding);
245+
enqueue(yieldOperand, encoding);
246+
continue;
247+
}
248+
229249
if (auto ifOp = currentValue.getDefiningOp<scf::IfOp>()) {
230250
if (stopPropagation && stopPropagation(ifOp))
231251
continue;

0 commit comments

Comments
 (0)