Skip to content

Commit aa7a897

Browse files
Merge commit '1f8966b53a3ba5c68294c551250438cca54c771f'
2 parents 182fb7f + 1f8966b commit aa7a897

File tree

10 files changed

+691
-99
lines changed

10 files changed

+691
-99
lines changed

bin/triton-tensor-layout.cpp

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -80,9 +80,18 @@ static cl::opt<std::string> TensorStr(
8080
//===--------------------------------------------------------------------===//
8181

8282
LogicalResult layoutPrint(RankedTensorType tensorType, raw_ostream &os) {
83-
// Dispatch to the corresponding dialect helper function to print the layout.
84-
os << triton::gpu::getLayoutStr(tensorType, UseHWPointOfView);
85-
return success();
83+
// DistributedEncodingTrait and SharedEncodingAttr implement the
84+
// toLinearLayout interface.
85+
mlir::Attribute layout = tensorType.getEncoding();
86+
if (isa<mlir::triton::gpu::DistributedEncodingTrait,
87+
mlir::triton::gpu::SharedEncodingAttr>(layout)) {
88+
os << triton::gpu::getLayoutStr(tensorType, UseHWPointOfView);
89+
return success();
90+
}
91+
92+
llvm::errs() << "Unsupported tensor layout attribute: "
93+
<< tensorType.getEncoding() << "\n";
94+
return failure();
8695
}
8796

8897
LogicalResult printLayoutFromFile(MLIRContext *context, StringRef filename,

include/triton/Dialect/TritonGPU/Transforms/Utility.h

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -161,9 +161,11 @@ Operation *cloneWithInferType(mlir::OpBuilder &rewriter, Operation *op,
161161
// Get backward slice of tensor values starting from the root node along with
162162
// encoding propagation.
163163
LogicalResult getConvertBackwardSlice(
164-
Value root, SetVector<Value> &slice, Attribute rootEncoding,
164+
OpOperand &root, SetVector<Value> &slice, Attribute rootEncoding,
165165
DenseMap<Value, Attribute> &layout,
166-
std::function<bool(Operation *)> stopPropagation = nullptr);
166+
std::function<bool(Operation *)> stopPropagation = nullptr,
167+
std::function<Value(OpOperand &, Attribute)> getExistingConversion =
168+
nullptr);
167169

168170
// Populate pattern to remove dead cycles in ForOp.
169171
void populateForOpDeadArgumentElimination(RewritePatternSet &patterns);

lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp

Lines changed: 65 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
#include "mlir/Analysis/SliceAnalysis.h"
44
#include "mlir/Dialect/SCF/IR/SCF.h"
55
#include "mlir/IR/BuiltinAttributes.h"
6+
#include "mlir/IR/Dominance.h"
67
#include "mlir/IR/IRMapping.h"
78
#include "mlir/IR/Matchers.h"
89
#include "mlir/IR/PatternMatch.h"
@@ -116,17 +117,15 @@ class LayoutPropagation {
116117
class LayoutRematerialization {
117118
public:
118119
LayoutRematerialization(FuncOp F) : funcOp(F) {}
120+
119121
// Map the original value to the remat'ed one.
120122
void addRematValue(Value old, Attribute encoding, Value newV);
121-
bool hasRematValue(Value value, Attribute encoding) {
122-
return rematMapping.contains({value, encoding});
123-
}
124-
// Return the remat'ed value in the given encoding.
125-
Value getRematValue(Value value, Attribute encoding) {
126-
auto it = rematMapping.find({value, encoding});
127-
assert(it != rematMapping.end());
128-
return it->second;
123+
// Get the remat'ed value in the given encoding, if one already exists and
124+
// is different than the layout conversion root.
125+
Value getRematValue(Value value, Attribute encoding) const {
126+
return rematMapping.lookup({value, encoding});
129127
}
128+
130129
void cleanup();
131130
void backwardRematerialization();
132131
void backwardRematerialization(ConvertLayoutOp convertOp);
@@ -137,6 +136,11 @@ class LayoutRematerialization {
137136
void rewriteSlice(SetVector<Value> &slice, DenseMap<Value, Attribute> &layout,
138137
ConvertLayoutOp convertOp);
139138

139+
LogicalResult getRematerializableSlice(
140+
OpOperand &root, Attribute rootEncoding, SetVector<Value> &slice,
141+
DenseMap<Value, Attribute> &layout,
142+
std::function<bool(Operation *)> stopPropagation = nullptr);
143+
140144
private:
141145
void updateRematMapping(SmallVector<std::tuple<Value, Value>> &values);
142146
// Existing tuples of (value, layout) that needs to be updated when recreating
@@ -148,6 +152,7 @@ class LayoutRematerialization {
148152
// DenseMap<std::pair<Operation*, Attribute>, Operation*>
149153
SetVector<Operation *> opToDelete;
150154
FuncOp funcOp;
155+
DominanceInfo domInfo;
151156
};
152157

153158
void LayoutRematerialization::addRematValue(Value old, Attribute encoding,
@@ -778,8 +783,8 @@ void LayoutRematerialization::rewriteSlice(SetVector<Value> &slice,
778783
auto layoutIt = layout.find(v);
779784
assert(layoutIt != layout.end());
780785
// If we already have a remat value for this value, use it.
781-
if (hasRematValue(v, layoutIt->second)) {
782-
mapping.map(v, getRematValue(v, layoutIt->second));
786+
if (Value remat = getRematValue(v, layoutIt->second)) {
787+
mapping.map(v, remat);
783788
valuesWithExistingRemat.insert(v);
784789
continue;
785790
}
@@ -940,12 +945,36 @@ void LayoutRematerialization::rewriteSlice(SetVector<Value> &slice,
940945
rewriteSlice(slice, layout, convertOp, mapping);
941946
}
942947

943-
LogicalResult getRematerializableSlice(
944-
Value root, Attribute rootEncoding, SetVector<Value> &slice,
948+
LogicalResult LayoutRematerialization::getRematerializableSlice(
949+
OpOperand &root, Attribute rootEncoding, SetVector<Value> &slice,
945950
DenseMap<Value, Attribute> &layout,
946-
std::function<bool(Operation *)> stopPropagation = nullptr) {
947-
LogicalResult result = getConvertBackwardSlice(root, slice, rootEncoding,
948-
layout, stopPropagation);
951+
std::function<bool(Operation *)> stopPropagation) {
952+
// Allow re-using existing conversions for a value. Check dominance of any
953+
// re-usable materializations against the root value. This is sufficient
954+
// because the conversions are processed in post-order.
955+
auto getExistingConversion = [&](OpOperand &value, Attribute encoding) {
956+
Value remat = getRematValue(value.get(), encoding);
957+
if (!remat)
958+
return Value();
959+
// `value` can be replaced with an existing rematerialization if it
960+
// dominates the current use of value.
961+
Operation *user = value.getOwner();
962+
if (domInfo.properlyDominates(remat, user)) {
963+
return remat;
964+
}
965+
// Alternatively, if the current use can be sunk below the existing
966+
// rematerialization, then it is okay to use as well. E.g. the current use
967+
// is a conversion that will be folded away when its result is
968+
// rematerialized.
969+
if (isa<ConvertLayoutOp>(user) && remat.getDefiningOp() &&
970+
domInfo.properlyDominates(user, remat.getDefiningOp())) {
971+
return remat;
972+
}
973+
return Value();
974+
};
975+
LogicalResult result =
976+
getConvertBackwardSlice(root, slice, rootEncoding, layout,
977+
stopPropagation, getExistingConversion);
949978
if (result.failed() || slice.empty())
950979
return failure();
951980

@@ -966,6 +995,12 @@ void LayoutRematerialization::backwardRematerialization() {
966995
[&](ConvertLayoutOp convertOp) { convertOps.push_back(convertOp); });
967996
for (ConvertLayoutOp convertOp : convertOps) {
968997
backwardRematerialization(convertOp);
998+
if (!opToDelete.contains(convertOp)) {
999+
// If the conversion didn't get removed, consider it for re-use in future
1000+
// backward slices.
1001+
addRematValue(convertOp.getSrc(), convertOp.getType().getEncoding(),
1002+
convertOp.getResult());
1003+
}
9691004
}
9701005
}
9711006

@@ -976,6 +1011,12 @@ void LayoutRematerialization::hoistConvertOnTopOfExtOrBroadcast() {
9761011
[&](ConvertLayoutOp convertOp) { convertOps.push_back(convertOp); });
9771012
for (ConvertLayoutOp convertOp : convertOps) {
9781013
hoistConvertOnTopOfExtOrBroadcast(convertOp);
1014+
if (!opToDelete.contains(convertOp)) {
1015+
// If the conversion didn't get removed, consider it for re-use in future
1016+
// backward slices.
1017+
addRematValue(convertOp.getSrc(), convertOp.getType().getEncoding(),
1018+
convertOp.getResult());
1019+
}
9791020
}
9801021
}
9811022

@@ -988,14 +1029,14 @@ void LayoutRematerialization::backwardRematerialization(
9881029
// careful with the heuristics for both correctness and perf
9891030
if (isa<DotOperandEncodingAttr, LinearEncodingAttr>(targetType.getEncoding()))
9901031
return;
991-
Value oldV = convertOp->getOperand(0);
1032+
Value oldV = convertOp.getSrc();
9921033
LDBG("check backward remat with source " << oldV << " encoding "
9931034
<< targetType.getEncoding());
9941035
// Check to see if there are existing remat'ed values for the pair of oldValue
995-
// and encoding.
996-
if (hasRematValue(oldV, targetType.getEncoding())) {
1036+
// and encoding. Make sure it dominates the current conversion.
1037+
Value newV = getRematValue(oldV, targetType.getEncoding());
1038+
if (newV && domInfo.properlyDominates(newV, convertOp)) {
9971039
// Replace it with the remat'ed value.
998-
Value newV = getRematValue(oldV, targetType.getEncoding());
9991040
convertOp.replaceAllUsesWith(newV);
10001041
opToDelete.insert(convertOp);
10011042
LDBG("found remat'ed value" << newV);
@@ -1007,7 +1048,7 @@ void LayoutRematerialization::backwardRematerialization(
10071048
SetVector<Value> slice;
10081049
DenseMap<Value, Attribute> layout;
10091050
LogicalResult result = getRematerializableSlice(
1010-
convertOp.getSrc(), targetType.getEncoding(), slice, layout);
1051+
convertOp.getSrcMutable(), targetType.getEncoding(), slice, layout);
10111052
if (result.failed()) {
10121053
LDBG(" getRematerializableSlice failed");
10131054
return;
@@ -1050,9 +1091,9 @@ void LayoutRematerialization::hoistConvertOnTopOfExtOrBroadcast(
10501091
// 1. Take a backward slice of all the tensor dependencies.
10511092
SetVector<Value> slice;
10521093
DenseMap<Value, Attribute> layout;
1053-
LogicalResult result =
1054-
getRematerializableSlice(convertOp.getSrc(), targetType.getEncoding(),
1055-
slice, layout, isExtOrBroadcastOp);
1094+
LogicalResult result = getRematerializableSlice(
1095+
convertOp.getSrcMutable(), targetType.getEncoding(), slice, layout,
1096+
isExtOrBroadcastOp);
10561097
if (result.failed())
10571098
return;
10581099

@@ -1070,7 +1111,7 @@ void LayoutRematerialization::hoistConvertOnTopOfExtOrBroadcast(
10701111
if (!srcEncoding)
10711112
return;
10721113
LogicalResult result = getRematerializableSlice(
1073-
op->getOperand(0), srcEncoding, tempSlice, tempLayout);
1114+
op->getOpOperand(0), srcEncoding, tempSlice, tempLayout);
10741115
// If we can rematerialize the rest of the ext slice we can ignore this
10751116
// ext as it won't need a convert.
10761117
if (result.succeeded()) {

lib/Dialect/TritonGPU/Transforms/Utility.cpp

Lines changed: 40 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -770,44 +770,60 @@ static bool isFreeConvert(Operation *op) {
770770
convertOp.getType());
771771
}
772772

773-
LogicalResult
774-
getConvertBackwardSlice(Value root, SetVector<Value> &slice,
775-
Attribute rootEncoding,
776-
DenseMap<Value, Attribute> &layout,
777-
std::function<bool(Operation *)> stopPropagation) {
778-
DenseSet<std::pair<Value, Attribute>> seen;
779-
SmallVector<std::pair<Value, Attribute>> queue;
780-
781-
auto enqueue = [&](Value operand, Attribute encoding) {
782-
auto x = std::make_pair(operand, encoding);
773+
LogicalResult getConvertBackwardSlice(
774+
OpOperand &root, SetVector<Value> &slice, Attribute rootEncoding,
775+
DenseMap<Value, Attribute> &layout,
776+
std::function<bool(Operation *)> stopPropagation,
777+
std::function<Value(OpOperand &, Attribute)> getExistingConversion) {
778+
DenseSet<std::pair<OpOperand *, Attribute>> seen;
779+
SmallVector<std::pair<OpOperand *, Attribute>> queue;
780+
781+
auto enqueue = [&](OpOperand &operand, Attribute encoding) {
782+
auto x = std::make_pair(&operand, encoding);
783783
if (!seen.insert(x).second) {
784784
return; // Already enqueued, skip
785785
}
786786
queue.push_back(x);
787787
};
788788
enqueue(root, rootEncoding);
789789

790+
auto updateLayout = [&](Value value, Attribute encoding) {
791+
assert((isa<RankedTensorType>(value.getType())));
792+
slice.insert(value);
793+
if (layout.find(value) != layout.end()) {
794+
if (layout[value] != encoding)
795+
return failure();
796+
}
797+
layout[value] = encoding;
798+
return success();
799+
};
800+
790801
while (!queue.empty()) {
791-
auto [currentValue, encoding] = queue.back();
802+
auto [currentValueUse, encoding] = queue.back();
803+
Value currentValue = currentValueUse->get();
792804
queue.pop_back();
793805
if (!isa<RankedTensorType>(currentValue.getType()))
794806
continue;
795807
// Skip propagating through for op results for now.
796808
// TODO: enable this based on needs.
797809
if (currentValue.getDefiningOp<scf::ForOp>())
798810
return failure();
799-
slice.insert(currentValue);
800-
if (layout.find(currentValue) != layout.end()) {
801-
if (layout[currentValue] != encoding)
811+
if (failed(updateLayout(currentValue, encoding)))
812+
return failure();
813+
814+
Value existing;
815+
if (getExistingConversion &&
816+
(existing = getExistingConversion(*currentValueUse, encoding))) {
817+
if (failed(updateLayout(existing, encoding)))
802818
return failure();
819+
currentValue = existing;
803820
}
804-
layout[currentValue] = encoding;
805821

806822
if (auto ifOp = currentValue.getDefiningOp<scf::IfOp>()) {
807823
unsigned argIdx = mlir::cast<OpResult>(currentValue).getResultNumber();
808824

809-
auto thenValue = ifOp.thenYield().getOperand(argIdx);
810-
auto elseValue = ifOp.elseYield().getOperand(argIdx);
825+
OpOperand &thenValue = ifOp.thenYield()->getOpOperand(argIdx);
826+
OpOperand &elseValue = ifOp.elseYield()->getOpOperand(argIdx);
811827

812828
enqueue(thenValue, encoding);
813829
enqueue(elseValue, encoding);
@@ -819,10 +835,11 @@ getConvertBackwardSlice(Value root, SetVector<Value> &slice,
819835
for (Value result : definingOp->getResults()) {
820836
if (result == currentValue || !isa<RankedTensorType>(result.getType()))
821837
continue;
822-
enqueue(result, encoding);
838+
if (failed(updateLayout(result, encoding)))
839+
return failure();
823840
}
824841
if (isFreeConvert(definingOp)) {
825-
enqueue(definingOp->getOperand(0), encoding);
842+
enqueue(definingOp->getOpOperand(0), encoding);
826843
continue;
827844
}
828845
if (canFoldIntoConversion(definingOp, encoding))
@@ -837,10 +854,10 @@ getConvertBackwardSlice(Value root, SetVector<Value> &slice,
837854
auto srcEncoding = inferSrcEncoding(gather, encoding);
838855
if (!srcEncoding)
839856
return failure();
840-
enqueue(gather.getIndices(), srcEncoding);
857+
enqueue(gather.getIndicesMutable(), srcEncoding);
841858
continue;
842859
}
843-
for (auto [i, operand] : llvm::enumerate(definingOp->getOperands())) {
860+
for (auto [i, operand] : llvm::enumerate(definingOp->getOpOperands())) {
844861
auto srcEncoding = inferSrcEncoding(definingOp, encoding);
845862
if (!srcEncoding)
846863
return failure();
@@ -853,9 +870,9 @@ getConvertBackwardSlice(Value root, SetVector<Value> &slice,
853870
Operation *parentOp = block->getParentOp();
854871
if (auto forOp = dyn_cast<scf::ForOp>(parentOp)) {
855872
OpOperand *initOperand = forOp.getTiedLoopInit(blockArg);
856-
Value yieldOperand = forOp.getBody()->getTerminator()->getOperand(
873+
OpOperand &yieldOperand = forOp.getBody()->getTerminator()->getOpOperand(
857874
blockArg.getArgNumber() - forOp.getNumInductionVars());
858-
enqueue(initOperand->get(), encoding);
875+
enqueue(*initOperand, encoding);
859876
enqueue(yieldOperand, encoding);
860877
continue;
861878
}

0 commit comments

Comments
 (0)