
Commit 182fb7f

Merge commit 'b155d8a8d47f391f43c0ad93d65104e3dbfa6e69'
2 parents: 0e47c27 + b155d8a

24 files changed: +876 -303 lines

.github/workflows/integration-tests.yml

Lines changed: 1 addition & 1 deletion

@@ -427,7 +427,7 @@ jobs:
           pytest --capture=tee-sys -rfs python/tutorials/06-fused-attention.py
           pytest --capture=tee-sys -rfs third_party/amd/python/test/test_extract_slice.py
           cd python/test/unit
-          pytest --capture=tee-sys -rfs -n 16 language runtime \
+          pytest --capture=tee-sys -rfs -n 12 language runtime \
             --ignore=language/test_line_info.py \
             --ignore=test_debug.py
           # TODO: uncomment

.github/workflows/integration-tests.yml.in

Lines changed: 1 addition & 1 deletion

@@ -414,7 +414,7 @@ jobs:
           pytest --capture=tee-sys -rfs python/tutorials/06-fused-attention.py
           pytest --capture=tee-sys -rfs third_party/amd/python/test/test_extract_slice.py
           cd python/test/unit
-          pytest --capture=tee-sys -rfs -n 16 language runtime \
+          pytest --capture=tee-sys -rfs -n 12 language runtime \
             --ignore=language/test_line_info.py \
             --ignore=test_debug.py
           # TODO: uncomment

include/triton/Conversion/TritonGPUToLLVM/Utility.h

Lines changed: 20 additions & 7 deletions

@@ -273,12 +273,25 @@ struct SharedMemoryObject {
                      ArrayRef<Value> offsets)
       : base(base), baseElemType(baseElemType),
         strides(strides.begin(), strides.end()),
-        offsets(offsets.begin(), offsets.end()) {}
+        offsets(offsets.begin(), offsets.end()) {
+    assert(strides.size() == offsets.size());
+  }
 
   SharedMemoryObject(Value base, Type baseElemType, ArrayRef<int64_t> shape,
-                     ArrayRef<unsigned> order, Location loc,
+                     triton::gpu::SharedEncodingAttr layout, Location loc,
                      RewriterBase &rewriter)
       : base(base), baseElemType(baseElemType) {
+    SmallVector<unsigned> order(shape.size());
+    // Default minor-to-major order
+    std::iota(order.rbegin(), order.rend(), 0);
+    if (layout) {
+      auto layoutOrder = convertType<int>(layout.getOrder());
+      int rankDiff = layoutOrder.size() - shape.size();
+      auto minRank = std::min(shape.size(), layoutOrder.size());
+      for (size_t i = 0; i < minRank; ++i)
+        order[i] = layoutOrder[i] - rankDiff;
+    }
+    assert(isPermutationOfIota(order) && "Invalid order");
     strides = getStridesFromShapeAndOrder(shape, order, loc, rewriter);
     offsets.append(order.size(), i32_val(0));
   }

@@ -304,14 +317,14 @@ struct SharedMemoryObject {
     return types;
   }
 
-  Value getCSwizzleOffset(int order) const {
-    assert(order >= 0 && order < strides.size());
-    return offsets[order];
+  Value getCSwizzleOffset(int dim) const {
+    assert(dim >= 0 && dim < strides.size());
+    return offsets[dim];
   }
 
-  Value getBaseBeforeSlice(int order, Location loc,
+  Value getBaseBeforeSlice(int dim, Location loc,
                            RewriterBase &rewriter) const {
-    Value cSwizzleOffset = getCSwizzleOffset(order);
+    Value cSwizzleOffset = getCSwizzleOffset(dim);
     Value offset = sub(i32_val(0), cSwizzleOffset);
     Type type = base.getType();
     return gep(type, baseElemType, base, offset);
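
Note on the new constructor: when no shared encoding is supplied it falls back to minor-to-major order, and otherwise it overlays the layout's order shifted by the rank difference between layout and shape. A minimal standalone sketch of that derivation on plain std types (deriveOrder is a hypothetical name, and the null-layout case is modeled as an empty order; the real code operates on MLIR attributes):

#include <algorithm>
#include <cassert>
#include <cstddef>
#include <numeric>
#include <vector>

// Hypothetical standalone analogue of the constructor's order derivation.
std::vector<unsigned> deriveOrder(std::size_t shapeRank,
                                  const std::vector<int> &layoutOrder) {
  std::vector<unsigned> order(shapeRank);
  // Default minor-to-major order: [rank-1, ..., 1, 0].
  std::iota(order.rbegin(), order.rend(), 0);
  if (!layoutOrder.empty()) { // mirrors `if (layout)`
    int rankDiff =
        static_cast<int>(layoutOrder.size()) - static_cast<int>(shapeRank);
    std::size_t minRank = std::min(shapeRank, layoutOrder.size());
    // Overlay the layout's order, shifted by the rank difference.
    for (std::size_t i = 0; i < minRank; ++i)
      order[i] = static_cast<unsigned>(layoutOrder[i] - rankDiff);
  }
  return order;
}

int main() {
  // No layout: pure minor-to-major default.
  assert((deriveOrder(3, {}) == std::vector<unsigned>{2, 1, 0}));
  // Matching ranks: the layout order is taken verbatim (rankDiff == 0).
  assert((deriveOrder(2, {0, 1}) == std::vector<unsigned>{0, 1}));
  // Layout rank exceeds shape rank: indices shift down by rankDiff == 1.
  assert((deriveOrder(2, {2, 1, 0}) == std::vector<unsigned>{1, 0}));
  return 0;
}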

include/triton/Dialect/Triton/IR/Utility.h

Lines changed: 1 addition & 1 deletion

@@ -148,7 +148,7 @@ template <typename T> bool isPermutationOfIota(ArrayRef<T> vals) {
   return isIota(sorted);
 }
 
-template <typename VecT> bool IsPermutationOfIota(const VecT &vec) {
+template <typename VecT> bool isPermutationOfIota(const VecT &vec) {
   return isPermutationOfIota(ArrayRef(vec));
 }
 
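
Presumably the reason for the rename: the new SharedMemoryObject constructor above asserts isPermutationOfIota(order) on a SmallVector, which can only resolve to this VecT overload (template deduction does not convert SmallVector to ArrayRef), so the capitalized IsPermutationOfIota spelling would not have compiled. A standalone sketch of the contract, with semantics read off the ArrayRef overload:

#include <algorithm>
#include <cassert>
#include <numeric>
#include <vector>

// True iff vals contains every value in {0, 1, ..., n-1} exactly once,
// in any order -- i.e. it is a valid dimension permutation.
bool isPermutationOfIota(std::vector<int> vals) {
  std::vector<int> expected(vals.size());
  std::iota(expected.begin(), expected.end(), 0);
  std::sort(vals.begin(), vals.end());
  return vals == expected;
}

int main() {
  assert(isPermutationOfIota({2, 0, 1}));  // a valid dimension order
  assert(!isPermutationOfIota({0, 0, 1})); // duplicate entry
  assert(!isPermutationOfIota({1, 2, 3})); // missing 0
  return 0;
}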

lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp

Lines changed: 1 addition & 3 deletions

@@ -18,7 +18,6 @@ namespace {
 
 using ::mlir::LLVM::getMultiDimOffset;
 using ::mlir::LLVM::getSharedMemoryObjectFromStruct;
-using ::mlir::LLVM::getStridesFromShapeAndOrder;
 using ::mlir::LLVM::getWrappedMultiDimOffset;
 using ::mlir::LLVM::linearize;
 
@@ -380,8 +379,7 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
       return !useLegacyMMAConversion;
     }
     if (auto dotOperand = dyn_cast<DotOperandEncodingAttr>(layout)) {
-      if (isa<NvidiaMmaEncodingAttr, AMDMfmaEncodingAttr>(
-              dotOperand.getParent())) {
+      if (isa<MmaEncodingTrait>(dotOperand.getParent())) {
         return !useLegacyMMAConversion;
       }
       return false;
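
Folding the concrete-type list into one interface check also widens it: AMDWmmaEncodingAttr presumably implements MmaEncodingTrait as well, so WMMA parents now take this path too, consistent with the WMMA enablement elsewhere in this commit. A toy analogue of the pattern (plain C++ classes, not the MLIR attribute/interface machinery):

#include <cassert>

// Toy stand-ins for the encoding attributes; in Triton these are MLIR
// attributes and MmaEncodingTrait is an attribute interface.
struct Encoding { virtual ~Encoding() = default; };
struct MmaEncodingTrait {};
struct NvidiaMmaEncoding : Encoding, MmaEncodingTrait {};
struct AMDMfmaEncoding : Encoding, MmaEncodingTrait {};
struct AMDWmmaEncoding : Encoding, MmaEncodingTrait {};
struct BlockedEncoding : Encoding {};

// One interface check replaces the explicit list of concrete encodings.
bool isMmaParent(const Encoding &e) {
  return dynamic_cast<const MmaEncodingTrait *>(&e) != nullptr;
}

int main() {
  assert(isMmaParent(NvidiaMmaEncoding{}));
  assert(isMmaParent(AMDWmmaEncoding{})); // newly covered vs. the old list
  assert(!isMmaParent(BlockedEncoding{}));
  return 0;
}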

lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandFMA.cpp

Lines changed: 0 additions & 1 deletion

@@ -4,7 +4,6 @@
 using ValueTable = std::map<std::pair<int, int>, Value>;
 using ::mlir::LLVM::delinearize;
 using ::mlir::LLVM::getSharedMemoryObjectFromStruct;
-using ::mlir::LLVM::getStridesFromShapeAndOrder;
 using ::mlir::LLVM::linearize;
 using ::mlir::triton::gpu::DotOperandEncodingAttr;
 using ::mlir::triton::gpu::getContigPerThread;

lib/Conversion/TritonGPUToLLVM/MemoryOpToLLVM.cpp

Lines changed: 2 additions & 14 deletions

@@ -78,23 +78,11 @@ struct LocalAllocOpConversion
     auto typeConverter = getTypeConverter();
     auto sharedLayout =
         cast<triton::gpu::SharedEncodingAttr>(resultTy.getEncoding());
-    auto order = sharedLayout.getOrder();
-    // Workaround for 3D tensors
-    // TODO: we need to modify the pipeline pass to give a proper shared
-    // encoding to 3D tensors
-    SmallVector<unsigned> newOrder;
-    if (resultTy.getShape().size() != order.size()) {
-      for (auto i = 0; i < order.size(); ++i)
-        newOrder.push_back(order[i] + 1);
-      newOrder.push_back(0);
-    } else {
-      newOrder = SmallVector<unsigned>(order.begin(), order.end());
-    }
 
     auto llvmElemTy = typeConverter->convertType(resultTy.getElementType());
     auto shapePerCTA = getShapePerCTA(sharedLayout, resultTy.getShape());
     auto smemObj = SharedMemoryObject(smemBase, llvmElemTy, shapePerCTA,
-                                      newOrder, loc, rewriter);
+                                      sharedLayout, loc, rewriter);
     // If there is an initial tensor, store it into the shared memory.
     if (op.getSrc()) {
       lowerDistributedToShared(loc, op.getSrc(), op.getResult(),

@@ -159,7 +147,7 @@ struct LocalLoadOpConversion : public ConvertOpToLLVMPattern<LocalLoadOp> {
           srcTy.getShape()[1] >= 4 * kWidth & dstTy.getRank() <= 2;
       return !canUseLdmatrix;
     }
-    if (isa<AMDMfmaEncodingAttr>(dot.getParent()))
+    if (isa<AMDMfmaEncodingAttr, AMDWmmaEncodingAttr>(dot.getParent()))
       return true;
   }
   return false;
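
The deleted 3-D workaround is not lost; it is subsumed by the rank-difference logic in the new SharedMemoryObject constructor (see the Utility.h hunk above). A quick sanity check of that equivalence, with the deleted branch transcribed into a standalone helper (hypothetical name oldWorkaround3D):

#include <cassert>
#include <vector>

// The deleted workaround, same logic: shift each 2-D order index up by one
// and append dim 0 as the new slowest-varying dimension.
std::vector<unsigned> oldWorkaround3D(const std::vector<unsigned> &order2D) {
  std::vector<unsigned> newOrder;
  for (unsigned o : order2D)
    newOrder.push_back(o + 1);
  newOrder.push_back(0);
  return newOrder;
}

int main() {
  // For a 3-D shape with a 2-D layout order, the new constructor computes
  // rankDiff == 2 - 3 == -1, so order[i] = layoutOrder[i] + 1 for i < 2,
  // and the minor-to-major default leaves order[2] == 0 -- the same result:
  assert((oldWorkaround3D({1, 0}) == std::vector<unsigned>{2, 1, 0}));
  assert((oldWorkaround3D({0, 1}) == std::vector<unsigned>{1, 2, 0}));
  return 0;
}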

lib/Conversion/TritonGPUToLLVM/Utility.cpp

Lines changed: 5 additions & 2 deletions

@@ -189,9 +189,9 @@ Value getSmemVecAddr(RankedTensorType registerTy,
       dyn_cast<triton::gpu::SharedEncodingAttr>(sharedTy.getEncoding());
 
   auto smemBase = smemObj.getBase();
-  auto sharedOrder = triton::gpu::getOrder(sharedTy.getEncoding());
   auto smemOffsets = smemObj.getOffsets();
   auto smemStrides = smemObj.getStrides();
+  auto smemOrder = sharedEnc.getOrder();
   Value smemOffset;
   // When loading or storing to shared memory, we consider two cases for
   // performance reasons:

@@ -239,9 +239,11 @@ Value getSmemVecAddr(RankedTensorType registerTy,
     // Reorder strides according to `order`. This way they match the
     // multi-dimensional offsets in regToSharedLayout.
     smemOffset = dot(rewriter, loc, smemOffsets,
-                     applyPermutation(smemStrides, sharedOrder));
+                     applyPermutation(smemStrides, smemOrder));
   } else { // Case 2 -> rank-reduced swizzling
     assert(rank >= 2 && "Swizzling only applies to tensors with rank >= 2");
+    assert(!sharedEnc.getHasLeadingOffset() &&
+           "Leading offsets are not supported for sliced tensors");
     // We define both tensor offsets and shared memory offsets:
     //
     // - Tensor offsets: Relative offsets within a given tensor.

@@ -572,6 +574,7 @@ SmallVector<Value> getStridesFromShapeAndOrder(ArrayRef<int64_t> shape,
                                                ArrayRef<unsigned> order,
                                                Location loc,
                                                RewriterBase &rewriter) {
+  assert(order.size() == shape.size() && "shape and order must have same size");
   auto rank = shape.size();
   SmallVector<Value> strides(rank);
   int64_t stride = 1;
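
The new assert guards getStridesFromShapeAndOrder, which writes strides[order[i]]; a mismatched order would index out of bounds. A scalar sketch of the computation it performs with IR values (plain int64 arithmetic, hypothetical standalone helper; order[0] is the fastest-varying dimension, following Triton's convention):

#include <cassert>
#include <cstdint>
#include <vector>

std::vector<int64_t>
stridesFromShapeAndOrder(const std::vector<int64_t> &shape,
                         const std::vector<unsigned> &order) {
  // The assert added in the diff: a mismatched order would index OOB below.
  assert(order.size() == shape.size() && "shape and order must have same size");
  std::vector<int64_t> strides(shape.size());
  int64_t stride = 1;
  // Walk dimensions from fastest to slowest, accumulating the running product.
  for (unsigned dim : order) {
    strides[dim] = stride;
    stride *= shape[dim];
  }
  return strides;
}

int main() {
  // 16x64 tile, dim 1 fastest (row-major): strides {64, 1}.
  assert((stridesFromShapeAndOrder({16, 64}, {1, 0}) ==
          std::vector<int64_t>{64, 1}));
  // Same tile, dim 0 fastest (column-major): strides {1, 16}.
  assert((stridesFromShapeAndOrder({16, 64}, {0, 1}) ==
          std::vector<int64_t>{1, 16}));
  return 0;
}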

lib/Dialect/TritonGPU/IR/Dialect.cpp

Lines changed: 3 additions & 5 deletions

@@ -1152,11 +1152,9 @@ SmallVector<unsigned> DotOperandEncodingAttr::getWarpsPerCTA() const {
 }
 SmallVector<unsigned> DotOperandEncodingAttr::getWarpOrder() const {
   // FIXME(Lezcano): Preexisting. Do we want to have this path at all?
-  if (mlir::isa<AMDMfmaEncodingAttr>(getParent())) {
+  if (mlir::isa<AMDMfmaEncodingAttr, AMDWmmaEncodingAttr>(getParent())) {
     return ::getWarpOrder(getParent());
   }
-  // It's quite weird to talk about warp order when that the warps
-  // are broadcasted along the K dimension
   llvm::report_fatal_error("DotOperandEncoding::getWarpOrder not implemented");
   return {};
 }

@@ -1201,9 +1199,9 @@ LogicalResult DotOperandEncodingAttr::verify(
 
   if (auto parentAttr = mlir::dyn_cast<AMDWmmaEncodingAttr>(parent)) {
     if (kWidth != 16 && parentAttr.getVersion() == 1 ||
-        kWidth != 8 && parentAttr.getVersion() == 2)
+        kWidth != 8 && kWidth != 16 && parentAttr.getVersion() == 2)
       return emitError() << "ttg.dot_op kWidth parameter must be 16 for "
-                            "gfx11 and 8 for gfx12";
+                            "gfx11 and 8/16 for gfx12";
     return success();
   }
 
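
The widened verifier condition reads correctly only because && binds tighter than ||. Spelled out with explicit parentheses in a standalone predicate (hypothetical helper name; version 1 = gfx11, version 2 = gfx12, per the error message itself):

#include <cassert>

// Rejects a kWidth that is invalid for the given WMMA version.
bool isInvalidKWidth(unsigned kWidth, unsigned version) {
  return (kWidth != 16 && version == 1) ||
         (kWidth != 8 && kWidth != 16 && version == 2);
}

int main() {
  assert(!isInvalidKWidth(16, 1)); // gfx11: only 16 is valid
  assert(isInvalidKWidth(8, 1));
  assert(!isInvalidKWidth(8, 2));  // gfx12: 8 and 16 are both valid now
  assert(!isInvalidKWidth(16, 2));
  assert(isInvalidKWidth(4, 2));
  return 0;
}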
