Skip to content

Commit 9829ce8

Browse files
authored
[BACKEND] Remove workarounds for 3d shapes of SharedMemoryObject (#5425)
1 parent 0e417ef commit 9829ce8

File tree

9 files changed

+37
-49
lines changed

9 files changed

+37
-49
lines changed

include/triton/Conversion/TritonGPUToLLVM/Utility.h

Lines changed: 20 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -273,12 +273,25 @@ struct SharedMemoryObject {
273273
ArrayRef<Value> offsets)
274274
: base(base), baseElemType(baseElemType),
275275
strides(strides.begin(), strides.end()),
276-
offsets(offsets.begin(), offsets.end()) {}
276+
offsets(offsets.begin(), offsets.end()) {
277+
assert(strides.size() == offsets.size());
278+
}
277279

278280
SharedMemoryObject(Value base, Type baseElemType, ArrayRef<int64_t> shape,
279-
ArrayRef<unsigned> order, Location loc,
281+
triton::gpu::SharedEncodingAttr layout, Location loc,
280282
RewriterBase &rewriter)
281283
: base(base), baseElemType(baseElemType) {
284+
SmallVector<unsigned> order(shape.size());
285+
// Default minor-to-major order
286+
std::iota(order.rbegin(), order.rend(), 0);
287+
if (layout) {
288+
auto layoutOrder = convertType<int>(layout.getOrder());
289+
int rankDiff = layoutOrder.size() - shape.size();
290+
auto minRank = std::min(shape.size(), layoutOrder.size());
291+
for (size_t i = 0; i < minRank; ++i)
292+
order[i] = layoutOrder[i] - rankDiff;
293+
}
294+
assert(isPermutationOfIota(order) && "Invalid order");
282295
strides = getStridesFromShapeAndOrder(shape, order, loc, rewriter);
283296
offsets.append(order.size(), i32_val(0));
284297
}
@@ -304,14 +317,14 @@ struct SharedMemoryObject {
304317
return types;
305318
}
306319

307-
Value getCSwizzleOffset(int order) const {
308-
assert(order >= 0 && order < strides.size());
309-
return offsets[order];
320+
Value getCSwizzleOffset(int dim) const {
321+
assert(dim >= 0 && dim < strides.size());
322+
return offsets[dim];
310323
}
311324

312-
Value getBaseBeforeSlice(int order, Location loc,
325+
Value getBaseBeforeSlice(int dim, Location loc,
313326
RewriterBase &rewriter) const {
314-
Value cSwizzleOffset = getCSwizzleOffset(order);
327+
Value cSwizzleOffset = getCSwizzleOffset(dim);
315328
Value offset = sub(i32_val(0), cSwizzleOffset);
316329
Type type = base.getType();
317330
return gep(type, baseElemType, base, offset);

include/triton/Dialect/Triton/IR/Utility.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -148,7 +148,7 @@ template <typename T> bool isPermutationOfIota(ArrayRef<T> vals) {
148148
return isIota(sorted);
149149
}
150150

151-
template <typename VecT> bool IsPermutationOfIota(const VecT &vec) {
151+
template <typename VecT> bool isPermutationOfIota(const VecT &vec) {
152152
return isPermutationOfIota(ArrayRef(vec));
153153
}
154154

lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,6 @@ namespace {
1818

1919
using ::mlir::LLVM::getMultiDimOffset;
2020
using ::mlir::LLVM::getSharedMemoryObjectFromStruct;
21-
using ::mlir::LLVM::getStridesFromShapeAndOrder;
2221
using ::mlir::LLVM::getWrappedMultiDimOffset;
2322
using ::mlir::LLVM::linearize;
2423

lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandFMA.cpp

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55
using ValueTable = std::map<std::pair<int, int>, Value>;
66
using ::mlir::LLVM::delinearize;
77
using ::mlir::LLVM::getSharedMemoryObjectFromStruct;
8-
using ::mlir::LLVM::getStridesFromShapeAndOrder;
98
using ::mlir::LLVM::linearize;
109
using ::mlir::triton::gpu::DotOperandEncodingAttr;
1110
using ::mlir::triton::gpu::expandMatrixOrderWithBatch;

lib/Conversion/TritonGPUToLLVM/MemoryOpToLLVM.cpp

Lines changed: 1 addition & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -78,23 +78,11 @@ struct LocalAllocOpConversion
7878
auto typeConverter = getTypeConverter();
7979
auto sharedLayout =
8080
cast<triton::gpu::SharedEncodingAttr>(resultTy.getEncoding());
81-
auto order = sharedLayout.getOrder();
82-
// Workaround for 3D tensors
83-
// TODO: we need to modify the pipeline pass to give a proper shared
84-
// encoding to 3D tensors
85-
SmallVector<unsigned> newOrder;
86-
if (resultTy.getShape().size() != order.size()) {
87-
for (auto i = 0; i < order.size(); ++i)
88-
newOrder.push_back(order[i] + 1);
89-
newOrder.push_back(0);
90-
} else {
91-
newOrder = SmallVector<unsigned>(order.begin(), order.end());
92-
}
9381

9482
auto llvmElemTy = typeConverter->convertType(resultTy.getElementType());
9583
auto shapePerCTA = getShapePerCTA(sharedLayout, resultTy.getShape());
9684
auto smemObj = SharedMemoryObject(smemBase, llvmElemTy, shapePerCTA,
97-
newOrder, loc, rewriter);
85+
sharedLayout, loc, rewriter);
9886
// If there is an initial tensor, store it into the shared memory.
9987
if (op.getSrc()) {
10088
lowerDistributedToShared(loc, op.getSrc(), op.getResult(),

lib/Conversion/TritonGPUToLLVM/Utility.cpp

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -189,9 +189,9 @@ Value getSmemVecAddr(RankedTensorType registerTy,
189189
dyn_cast<triton::gpu::SharedEncodingAttr>(sharedTy.getEncoding());
190190

191191
auto smemBase = smemObj.getBase();
192-
auto sharedOrder = triton::gpu::getOrder(sharedTy.getEncoding());
193192
auto smemOffsets = smemObj.getOffsets();
194193
auto smemStrides = smemObj.getStrides();
194+
auto smemOrder = sharedEnc.getOrder();
195195
Value smemOffset;
196196
// When loading or storing to shared memory, we consider two cases for
197197
// performance reasons:
@@ -239,9 +239,11 @@ Value getSmemVecAddr(RankedTensorType registerTy,
239239
// Reorder strides according to `order`. This way they match the
240240
// multi-dimensional offsets in regToSharedLayout.
241241
smemOffset = dot(rewriter, loc, smemOffsets,
242-
applyPermutation(smemStrides, sharedOrder));
242+
applyPermutation(smemStrides, smemOrder));
243243
} else { // Case 2 -> rank-reduced swizzling
244244
assert(rank >= 2 && "Swizzling only applies to tensors with rank >= 2");
245+
assert(!sharedEnc.getHasLeadingOffset() &&
246+
"Leading offsets are not supported for sliced tensors");
245247
// We define both tensor offsets and shared memory offsets:
246248
//
247249
// - Tensor offsets: Relative offsets within a given tensor.
@@ -572,6 +574,7 @@ SmallVector<Value> getStridesFromShapeAndOrder(ArrayRef<int64_t> shape,
572574
ArrayRef<unsigned> order,
573575
Location loc,
574576
RewriterBase &rewriter) {
577+
assert(order.size() == shape.size() && "shape and order must have same size");
575578
auto rank = shape.size();
576579
SmallVector<Value> strides(rank);
577580
int64_t stride = 1;

third_party/amd/lib/TritonAMDGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandHelper.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -122,10 +122,11 @@ llvm::SmallVector<Value> computeOffsetsAType(
122122
SharedEncodingAttr srcLayout, unsigned nonKDim, unsigned kDim) {
123123
SmallVector<Value> strides = smemObj.getStrides();
124124
SmallVector<Value> offsets = smemObj.getOffsets();
125+
auto order = srcLayout.getOrder();
125126
auto rank = offsets.size();
126127

127128
int vectorSize = 1;
128-
if (srcLayout.getOrder()[0] == rank - 1) {
129+
if (order[0] == rank - 1) {
129130
if (isSwizzled(srcLayout))
130131
vectorSize = std::min(static_cast<int>(srcLayout.getVec()), numOfElems);
131132
else
@@ -136,7 +137,6 @@ llvm::SmallVector<Value> computeOffsetsAType(
136137
reps, offsets, vectorSize, nonKDim, kDim);
137138
const auto numBlocks = reps[reps.size() - 2];
138139
const auto blockSize = mapping.size();
139-
auto order = srcLayout.getOrder();
140140
llvm::SmallVector<Value> aOffsets(blockSize * numBlocks);
141141

142142
if (!isSwizzlePatternFitsIntoBlock(srcLayout, 0, reps, elemsPerInstr,
@@ -190,13 +190,14 @@ llvm::SmallVector<Value> computeOffsetsBType(
190190
// transposed operand A layout
191191
// this unifies axis order, so non-K dim is 0, k dim is 1
192192
auto rank = smemObj.getOffsets().size();
193+
auto order = srcLayout.getOrder();
193194
SmallVector<int64_t> tElemsPerInstr{elemsPerInstr[1], elemsPerInstr[0]};
194195
SmallVector<int64_t> tReps = transposeSpatialDims(reps);
195196
SmallVector<Value> tOffsets = transposeSpatialDims(smemObj.getOffsets());
196197
SmallVector<Value> tStrides = transposeSpatialDims(smemObj.getStrides());
197198

198199
int vectorSize = 1;
199-
if (srcLayout.getOrder()[0] == rank - 2) {
200+
if (order[0] == rank - 2) {
200201
if (isSwizzled(srcLayout))
201202
vectorSize = std::min(static_cast<int>(srcLayout.getVec()), numOfElems);
202203
else
@@ -207,7 +208,6 @@ llvm::SmallVector<Value> computeOffsetsBType(
207208
tReps, tOffsets, vectorSize, nonKDim, kDim);
208209
const auto numBlocks = tReps[tReps.size() - 2];
209210
const auto blockSize = mapping.size();
210-
auto order = srcLayout.getOrder();
211211
llvm::SmallVector<Value> bOffsets(blockSize * numBlocks);
212212

213213
if (!isSwizzlePatternFitsIntoBlock(srcLayout, 0, reps, elemsPerInstr,

third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ConvertLayoutOpToLLVM.cpp

Lines changed: 1 addition & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -423,20 +423,9 @@ struct LocalAllocOpConversion
423423
}
424424

425425
auto resultTy = cast<MemDescType>(op.getType());
426-
// Workaround for 3D tensors
427-
// TODO: we need to modify the pipeline pass to give a proper shared
428-
// encoding to 3D tensors
429-
SmallVector<unsigned> newOrder;
430-
if (resultTy.getShape().size() != order.size()) {
431-
for (auto i = 0; i < order.size(); ++i)
432-
newOrder.push_back(order[i] + 1);
433-
newOrder.push_back(0);
434-
} else {
435-
newOrder = SmallVector<unsigned>(order.begin(), order.end());
436-
}
437426
auto shapePerCTA = getShapePerCTA(sharedLayout, resultTy.getShape());
438427
auto smemObj = SharedMemoryObject(smemBase, llvmElemTy, shapePerCTA,
439-
newOrder, loc, rewriter);
428+
sharedLayout, loc, rewriter);
440429
auto retVal = getStructFromSharedMemoryObject(loc, smemObj, rewriter);
441430
rewriter.replaceOp(op, retVal);
442431
return success();

third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandMMAv2OrV3.cpp

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -7,10 +7,8 @@ using namespace mlir;
77
using ValueTable = std::map<std::array<int, 3>, Value>;
88
using ::mlir::LLVM::delinearize;
99
using ::mlir::LLVM::getSharedMemoryObjectFromStruct;
10-
using ::mlir::LLVM::getStridesFromShapeAndOrder;
1110
using ::mlir::triton::gpu::DotOperandEncodingAttr;
1211
using ::mlir::triton::gpu::getContigPerThread;
13-
using ::mlir::triton::gpu::getOrder;
1412
using ::mlir::triton::gpu::getShapePerCTA;
1513
using ::mlir::triton::gpu::getSizePerThread;
1614
using ::mlir::triton::gpu::getTotalElemsPerThread;
@@ -608,12 +606,11 @@ getLoadMatrixFn(MemDescType descTy, const SharedMemoryObject &smemObj,
608606
std::max<int>(shapePerCTA[2] / mmaLayout.getWarpsPerCTA()[2], 8);
609607
// (a, b) is the coordinate.
610608
auto load = [=, &rewriter, &vals](int batch, int a, int b) {
611-
MMA16816SmemLoader loader(nPerWarp, warpsPerTile, sharedLayout.getOrder(),
612-
mmaLayout.getWarpsPerCTA(), kOrder, kWidth,
613-
smemObj.strides, shapePerCTA /*tileShape*/,
614-
instrShape, matShape, multiDimWarpId, perPhase,
615-
maxPhase, elemBytes, mmaElemBytes, isHopper,
616-
rewriter, typeConverter, loc);
609+
MMA16816SmemLoader loader(
610+
nPerWarp, warpsPerTile, order, mmaLayout.getWarpsPerCTA(), kOrder,
611+
kWidth, smemObj.strides, shapePerCTA /*tileShape*/, instrShape,
612+
matShape, multiDimWarpId, perPhase, maxPhase, elemBytes, mmaElemBytes,
613+
isHopper, rewriter, typeConverter, loc);
617614
// Offset of a slice within the original tensor in shared memory
618615
Value cSwizzleOffset = smemObj.getCSwizzleOffset(order[0]);
619616
SmallVector<Value> offs = loader.computeOffsets(lane, cSwizzleOffset);

0 commit comments

Comments
 (0)