
Commit 03e391f

Merge OpenAI Triton commit 6f0ae97 (#3863)
This PR changes the Triton base from 711caa4 to 6f0ae97 (Apr 7). Pass rate: 90.91% -> 90.79%. Please do not squash and merge this PR.
2 parents 87b1723 + 4d8979f commit 03e391f

12 files changed: +492 -28 lines


include/triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUOps.td

Lines changed: 21 additions & 0 deletions
@@ -473,6 +473,27 @@ def TTNG_TMEMAllocOp : TTNG_Op<"tmem_alloc", [DeclareOpInterfaceMethods<MemoryEf
   let hasVerifier = 1;
 }
 
+def TTNG_TMEMSubSliceOp : TTNG_Op<"tmem_subslice", [Pure]> {
+  let summary = "Take a subslice of a tensor memory allocation";
+  let description = [{
+    This operation takes a subslice of a tensor memory allocation and returns a new descriptor
+    containing the address and a view of the subslice.
+    It is similar to ttg.memdesc_subview, except that the offset must be static and we can only
+    slice along the inner dimension of a 2D memdesc, as this is the only slice TMem supports.
+  }];
+  let arguments = (ins TTG_MemDescType:$src, I32Attr:$N);
+
+  let assemblyFormat = [{
+    $src attr-dict `:` qualified(type($src)) `->` qualified(type($result))
+  }];
+
+  let builders = [
+    OpBuilder<(ins "Value":$alloc, "int":$offset, "int":$size)>,
+  ];
+  let results = (outs TTG_MemDescType:$result);
+  let hasVerifier = 1;
+}
+
 def TTNG_TMEMCopyOp : TTNG_Op<"tmem_copy", [DeclareOpInterfaceMethods<MemoryEffectsOpInterface>]> {
   let summary = "Initiate an asynchronous copy operation from shared memory to the Tensor Memory.";

include/triton/Dialect/TritonNvidiaGPU/Transforms/Passes.h

Lines changed: 2 additions & 0 deletions
@@ -60,6 +60,8 @@ std::unique_ptr<Pass> createTritonNvidiaGPUPromoteLHSToTMemPass();
 
 std::unique_ptr<Pass> createTritonNvidiaGPUOptimizeDescriptorEncodingPass();
 
+std::unique_ptr<Pass> createTritonNvidiaGPUOptimizeTMemSubtilingPass();
+
 /// Generate the code for registering passes.
 #define GEN_PASS_REGISTRATION
 #define GEN_PASS_DECL_TRITONNVIDIAGPULEGALIZETMALAYOUTS

include/triton/Dialect/TritonNvidiaGPU/Transforms/Passes.td

Lines changed: 12 additions & 0 deletions
@@ -130,4 +130,16 @@ def TritonNvidiaGPUOptimizeDescriptorEncodingPass : Pass<"triton-nvidia-optimize
                            "mlir::triton::TritonDialect"];
 }
 
+def TritonNvidiaGPUOptimizeTMemSubtilingPass : Pass<"triton-nvidia-optimize-tmem-subtiling", "mlir::ModuleOp"> {
+  let summary = "Optimize subtiling.";
+
+  let description = [{
+    Optimize subtiling by trying to split tmem_load when the user splits the loaded tensor.
+  }];
+
+  let dependentDialects = ["mlir::triton::gpu::TritonGPUDialect",
+                           "mlir::triton::nvidia_gpu::TritonNvidiaGPUDialect",
+                           "mlir::triton::TritonDialect"];
+}
+
 #endif

lib/Dialect/TritonNvidiaGPU/IR/Ops.cpp

Lines changed: 40 additions & 0 deletions
@@ -485,6 +485,46 @@ void TMEMCopyOp::getEffects(
                        mlir::triton::gpu::SharedMemory::get());
 }
 
+// -- TMEMSubSliceOp --
+LogicalResult TMEMSubSliceOp::verify() {
+  auto srcTy = cast<triton::gpu::MemDescType>(getSrc().getType());
+  auto encoding = dyn_cast<triton::nvidia_gpu::TensorMemoryEncodingAttr>(
+      srcTy.getEncoding());
+  if (!encoding)
+    return emitOpError("The source must be a tensor memory buffer.");
+  if (encoding.getBlockM() != 128)
+    return emitOpError("The source must be a 128xN layout.");
+  auto dstTy = cast<triton::gpu::MemDescType>(getResult().getType());
+  auto dstEncoding = dyn_cast<triton::nvidia_gpu::TensorMemoryEncodingAttr>(
+      dstTy.getEncoding());
+  if (!dstEncoding)
+    return emitOpError("The destination must be a tensor memory buffer.");
+  if (dstEncoding.getBlockM() != encoding.getBlockM() ||
+      dstEncoding.getCTASplitM() != encoding.getCTASplitM() ||
+      dstEncoding.getCTASplitN() != encoding.getCTASplitN() ||
+      dstEncoding.getUnpacked() != encoding.getUnpacked())
+    return emitOpError("The destination must have the same block size and "
+                       "CTASplit size as the source.");
+  return mlir::success();
+}
+
+void TMEMSubSliceOp::build(OpBuilder &builder, OperationState &state,
+                           Value alloc, int offset, int size) {
+  auto allocTy = cast<triton::gpu::MemDescType>(alloc.getType());
+  SmallVector<int64_t> shape(allocTy.getShape());
+  shape.back() = size;
+  auto encoding =
+      cast<triton::nvidia_gpu::TensorMemoryEncodingAttr>(allocTy.getEncoding());
+  unsigned newBlockN = std::min<unsigned>(encoding.getBlockN(), size);
+  auto newEncoding = triton::nvidia_gpu::TensorMemoryEncodingAttr::get(
+      builder.getContext(), encoding.getBlockM(), newBlockN,
+      encoding.getUnpacked(), encoding.getCTASplitM(), encoding.getCTASplitN());
+  auto subsliceType = gpu::MemDescType::get(
+      shape, allocTy.getElementType(), newEncoding, allocTy.getMemorySpace(),
+      allocTy.getMutableMemory());
+  build(builder, state, subsliceType, alloc, offset);
+}
+
 } // namespace nvidia_gpu
 } // namespace triton
 } // namespace mlir

lib/Dialect/TritonNvidiaGPU/Transforms/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
@@ -2,6 +2,7 @@ add_triton_library(TritonNvidiaGPUTransforms
   FenceInsertion.cpp
   MMALowering.cpp
   OptimizeDescriptorEncoding.cpp
+  OptimizeTMemSubtiling.cpp
   PlanCTA.cpp
   PromoteLHSToTMem.cpp
   TensorMemoryAllocation.cpp
lib/Dialect/TritonNvidiaGPU/Transforms/OptimizeTMemSubtiling.cpp

Lines changed: 207 additions & 0 deletions
@@ -0,0 +1,207 @@
+#include "mlir/IR/TypeUtilities.h"
+#include "mlir/Pass/PassManager.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "triton/Dialect/Triton/IR/Dialect.h"
+#include "triton/Dialect/Triton/IR/Types.h"
+#include "triton/Dialect/TritonGPU/IR/Attributes.h"
+#include "triton/Dialect/TritonGPU/IR/Dialect.h"
+#include "triton/Dialect/TritonGPU/IR/TritonGPUInterfaces.h"
+#include "triton/Dialect/TritonGPU/Transforms/Utility.h"
+#include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h"
+#include "triton/Dialect/TritonNvidiaGPU/Transforms/Passes.h"
+#include "triton/Dialect/TritonNvidiaGPU/Transforms/TMAUtilities.h"
+
+namespace {
+
+using namespace mlir;
+
+namespace ttng = triton::nvidia_gpu;
+namespace ttg = triton::gpu;
+namespace tt = triton;
+
+#define GEN_PASS_CLASSES
+#include "triton/Dialect/TritonNvidiaGPU/Transforms/Passes.h.inc"
+
+// If we don't know the effects of the op, we add all possible effects.
+static void addAllValuelessEffects(
+    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
+  effects.emplace_back(MemoryEffects::Effect::get<MemoryEffects::Read>());
+  effects.emplace_back(MemoryEffects::Effect::get<MemoryEffects::Write>());
+  effects.emplace_back(MemoryEffects::Effect::get<MemoryEffects::Allocate>());
+  effects.emplace_back(MemoryEffects::Effect::get<MemoryEffects::Free>());
+}
+
+static bool
+collectEffects(Operation *op,
+               SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
+  // Collect the effect instances of the operation. Note that the
+  // implementation of getEffects erases all effect instances whose type
+  // differs from the template parameter, so we collect them first in a local
+  // buffer and then copy.
+  if (auto iface = dyn_cast<MemoryEffectOpInterface>(op)) {
+    SmallVector<MemoryEffects::EffectInstance> localEffects;
+    iface.getEffects(localEffects);
+    llvm::append_range(effects, localEffects);
+    return true;
+  }
+  if (op->hasTrait<OpTrait::HasRecursiveMemoryEffects>()) {
+    for (auto &region : op->getRegions()) {
+      for (auto &block : region) {
+        for (auto &innerOp : block)
+          if (!collectEffects(&innerOp, effects))
+            return false;
+      }
+    }
+    return true;
+  }
+
+  // We need to be conservative here in case the op doesn't have the interface
+  // and assume it can have any possible effect.
+  addAllValuelessEffects(effects);
+  return false;
+}
+
+// Sink tmem_loads as close to their use as possible to reduce register
+// pressure.
+static void sinkLoad(ttng::TMEMLoadOp load, Operation *cvt) {
+  Operation *insertBefore = nullptr;
+  Operation *next = cvt->getNextNode();
+  while (next && !next->hasTrait<OpTrait::IsTerminator>()) {
+    insertBefore = next;
+    bool dep = false;
+    for (auto operand : getNestedOperands(next)) {
+      if (operand == cvt->getResult(0)) {
+        dep = true;
+        break;
+      }
+    }
+    if (!isMemoryEffectFree(next)) {
+      SmallVector<MemoryEffects::EffectInstance> effects;
+      collectEffects(next, effects);
+      for (auto effect : effects) {
+        if (effect.getEffect() ==
+                MemoryEffects::Effect::get<MemoryEffects::Write>() ||
+            effect.getEffect() ==
+                MemoryEffects::Effect::get<MemoryEffects::Allocate>()) {
+          if (effect.getResource() ==
+                  mlir::SideEffects::DefaultResource::get() ||
+              effect.getResource() ==
+                  mlir::triton::nvidia_gpu::TensorMemory::get()) {
+            dep = true;
+            break;
+          }
+        }
+      }
+    }
+    if (dep)
+      break;
+    next = next->getNextNode();
+  }
+  if (insertBefore) {
+    load->moveBefore(insertBefore);
+    cvt->moveBefore(insertBefore);
+  }
+}
+
+// clang-format off
+// Converts:
+// %l = ttng.tmem_load %o : !ttg.memdesc<128x256xf32, #tmem, #ttng.tensor_memory, mutable> -> tensor<128x256xf32, #blocked>
+// %r = tt.reshape %l : tensor<128x256xf32, #blocked> -> tensor<128x2x128xf32, #blocked4>
+// %t = tt.trans %r {order = array<i32: 0, 2, 1>} : tensor<128x2x128xf32, #blocked4> -> tensor<128x128x2xf32, #blocked5>
+// %outLHS, %outRHS = tt.split %t : tensor<128x128x2xf32, #blocked5> -> tensor<128x128xf32, #blocked2>
+// To:
+// %o0 = ttng.tmem_subslice %o { N = 0 }: !ttg.memdesc<128x256xf32, #tmem, #ttng.tensor_memory, mutable> -> !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable>
+// %outLHS = ttng.tmem_load %o0 : !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable> -> tensor<128x128xf32, #blocked>
+// %o1 = ttng.tmem_subslice %o { N = 128 }: !ttg.memdesc<128x256xf32, #tmem, #ttng.tensor_memory, mutable> -> !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable>
+// %outRHS = ttng.tmem_load %o1 : !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable> -> tensor<128x128xf32, #blocked>
+// clang-format on
+// This will change the layout of the destination tensor to distribute each
+// slice across warps. It currently only supports simple cases where tmem can
+// be sliced easily. This could be extended if needed with more powerful
+// slicing support for tmem.
+class TMemSplitLoadPattern : public OpRewritePattern<tt::SplitOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(tt::SplitOp splitOp,
+                                PatternRewriter &rewriter) const override {
+    auto src = splitOp.getSrc();
+    // Skip convert layout ops.
+    while (auto cvt = src.getDefiningOp<ttg::ConvertLayoutOp>()) {
+      src = cvt.getSrc();
+    }
+    // Only support splitting the N dimension at the outermost level.
+    auto transOp = src.getDefiningOp<tt::TransOp>();
+    if (!transOp || transOp.getOrder() != ArrayRef<int>({0, 2, 1}))
+      return failure();
+    auto reshapeOp = transOp.getSrc().getDefiningOp<tt::ReshapeOp>();
+    if (!reshapeOp)
+      return failure();
+    auto shape = reshapeOp.getResult().getType().getShape();
+    if (shape[0] != reshapeOp.getSrc().getType().getShape()[0])
+      return failure();
+    auto tmemLoad = reshapeOp.getSrc().getDefiningOp<ttng::TMEMLoadOp>();
+    if (!tmemLoad)
+      return failure();
+    // We found a tmem_load that is split on the N dimension. We can split it
+    // into multiple tmem_loads.
+    int mDim = getShapePerCTA(tmemLoad.getSrc().getType())[0];
+    // TODO: enable other M cases. (the layout is a bit more complex).
+    if (mDim != 128)
+      return failure();
+    int splitNSize = shape[2];
+    if (splitNSize < 8)
+      return failure();
+    Value tmem = tmemLoad.getSrc();
+    int numWarps = ttg::lookupNumWarps(tmemLoad);
+    rewriter.setInsertionPoint(tmemLoad);
+    // First slice.
+    Value subSlice0 = rewriter.create<ttng::TMEMSubSliceOp>(
+        tmemLoad.getLoc(), tmem, 0, splitNSize);
+    Attribute distLayout = ttng::getTmemCompatibleLayout(
+        mDim, splitNSize, splitOp.getOutLHS().getType(), numWarps);
+    RankedTensorType newLoadType = RankedTensorType::get(
+        splitOp.getOutLHS().getType().getShape(),
+        splitOp.getOutLHS().getType().getElementType(), distLayout);
+    auto load0 = rewriter.create<ttng::TMEMLoadOp>(tmemLoad.getLoc(),
+                                                   newLoadType, subSlice0);
+    auto cvt0 = rewriter.create<ttg::ConvertLayoutOp>(
+        tmemLoad.getLoc(), splitOp.getOutLHS().getType(), load0);
+    // Second slice.
+    Value subSlice1 = rewriter.create<ttng::TMEMSubSliceOp>(
+        tmemLoad.getLoc(), tmem, splitNSize, splitNSize);
+    auto load1 = rewriter.create<ttng::TMEMLoadOp>(tmemLoad.getLoc(),
+                                                   newLoadType, subSlice1);
+    auto cvt1 = rewriter.create<ttg::ConvertLayoutOp>(
+        tmemLoad.getLoc(), splitOp.getOutRHS().getType(), load1);
+    rewriter.replaceOp(splitOp, {cvt0, cvt1});
+    sinkLoad(load0, cvt0);
+    sinkLoad(load1, cvt1);
+    return success();
+  }
+};
+
+class TritonNvidiaGPUOptimizeTMemSubtilingPass
+    : public TritonNvidiaGPUOptimizeTMemSubtilingPassBase<
+          TritonNvidiaGPUOptimizeTMemSubtilingPass> {
+public:
+  using BaseT = TritonNvidiaGPUOptimizeTMemSubtilingPassBase<
+      TritonNvidiaGPUOptimizeTMemSubtilingPass>;
+  using BaseT::BaseT;
+
+  void runOnOperation() override {
+    MLIRContext *context = &getContext();
+    ModuleOp m = getOperation();
+
+    mlir::RewritePatternSet patterns(context);
+    patterns.add<TMemSplitLoadPattern>(context);
+    if (failed(applyPatternsGreedily(m, std::move(patterns))))
+      signalPassFailure();
+  }
+};
+
+} // namespace
+
+std::unique_ptr<Pass> mlir::createTritonNvidiaGPUOptimizeTMemSubtilingPass() {
+  return std::make_unique<TritonNvidiaGPUOptimizeTMemSubtilingPass>();
+}

python/test/unit/language/test_matmul.py

Lines changed: 22 additions & 8 deletions
@@ -35,7 +35,7 @@ def matmul_kernel( #
         stride_cm, stride_cn, #
         BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, #
         NUM_STAGES: tl.constexpr, SCALE_A: tl.constexpr = None, PRECISION: tl.constexpr = "ieee",
-        A_TRANS: tl.constexpr = False):
+        A_TRANS: tl.constexpr = False, EPILOGUE_SUBTILE: tl.constexpr = False):
     pid = tl.program_id(axis=0)
     num_pid_m = tl.cdiv(M, BLOCK_M)
     pid_m = pid % num_pid_m
@@ -63,10 +63,21 @@ def matmul_kernel( #
         accumulator = tl.dot(a, b, acc=accumulator, out_dtype=output_ptr.dtype.element_ty, input_precision=PRECISION)
         a_ptrs += BLOCK_K * stride_ak
         b_ptrs += BLOCK_K * stride_bk
-    offs_cm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
-    offs_cn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
-    output_ptrs = output_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]
-    tl.store(output_ptrs, accumulator)
+    if EPILOGUE_SUBTILE:
+        acc = tl.reshape(accumulator, (BLOCK_M, 2, BLOCK_N // 2))
+        acc = tl.permute(acc, (0, 2, 1))
+        acc0, acc1 = tl.split(acc)
+        offs_cm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
+        offs_cn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N // 2)
+        output_ptrs0 = output_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]
+        output_ptrs1 = output_ptrs0 + stride_cn * (BLOCK_N // 2)
+        tl.store(output_ptrs0, acc0)
+        tl.store(output_ptrs1, acc1)
+    else:
+        offs_cm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
+        offs_cn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
+        output_ptrs = output_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]
+        tl.store(output_ptrs, accumulator)
 
 
 def get_src_element_ty_size(dtype_str):
@@ -86,8 +97,9 @@ def get_src_element_ty_size(dtype_str):
                           (512, 64, 32, 2), (64, 16, 16, 4)])
 @pytest.mark.parametrize("NUM_CTAS", [1, 2])
 @pytest.mark.parametrize("NUM_WARPS", [4, 8])
-def test_simple_matmul(dtype_src_str, dtype_dst_str, BLOCK_M, BLOCK_N, BLOCK_K, NUM_STAGES, NUM_WARPS, NUM_CTAS,
-                       device):
+@pytest.mark.parametrize("EPILOGUE_SUBTILE", [True, False])
+def test_simple_matmul(dtype_src_str, dtype_dst_str, BLOCK_M, BLOCK_N, BLOCK_K, NUM_STAGES, NUM_WARPS, NUM_CTAS, device,
+                       EPILOGUE_SUBTILE):
     if NUM_CTAS > 1 and (not is_cuda() or torch.cuda.get_device_capability()[0] < 9):
         pytest.xfail("Clusters requires nvidia compute capability >= 9")
     if is_hip() and ((BLOCK_K * BLOCK_M + BLOCK_K * BLOCK_N) * NUM_STAGES * get_src_element_ty_size(dtype_src_str)
@@ -105,6 +117,8 @@ def test_simple_matmul(dtype_src_str, dtype_dst_str, BLOCK_M, BLOCK_N, BLOCK_K,
         pytest.skip("FMA matmul not supported for multiple CTAs")
     if (BLOCK_M < 64 or (BLOCK_M == 64 and BLOCK_N == 16)) and NUM_CTAS > 1:
         pytest.skip("multi-CTAs is broken for mmav2")
+    if EPILOGUE_SUBTILE and not is_xpu() and (is_hip() or NUM_CTAS > 1 or BLOCK_N >= 512):
+        pytest.skip("creates convert layout too big to fit in smem")
     M, N, K = 1024, 512, 256
     torch.manual_seed(42)
     precision = "tf32" if dtype_src_str == "tensorfloat32" else "ieee"
@@ -125,7 +139,7 @@ def test_simple_matmul(dtype_src_str, dtype_dst_str, BLOCK_M, BLOCK_N, BLOCK_K,
     grid = (triton.cdiv(M, BLOCK_M) * triton.cdiv(N, BLOCK_N), 1)
     k = matmul_kernel[grid](a, b, output, M, N, K, a.stride(0), a.stride(1), b.stride(0), b.stride(1), output.stride(0),
                             output.stride(1), BLOCK_M, BLOCK_N, BLOCK_K, NUM_STAGES=NUM_STAGES, PRECISION=precision,
-                            num_warps=NUM_WARPS, num_ctas=NUM_CTAS)
+                            num_warps=NUM_WARPS, num_ctas=NUM_CTAS, EPILOGUE_SUBTILE=EPILOGUE_SUBTILE)
     ref_out = torch.matmul(A, B).to(torch.float32)
     output = output.to(torch.float32)
     if dtype_src_str == "float32":
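
Note (not part of the diff): the EPILOGUE_SUBTILE path relies on reshape -> permute -> split decomposing the accumulator into its two contiguous column halves, which is exactly the shape the new tmem-subtiling pass can serve with two ttng.tmem_subslice + ttng.tmem_load pairs. A quick standalone sanity check of that algebra in plain PyTorch (not Triton; the block sizes below are chosen arbitrarily):

import torch

BLOCK_M, BLOCK_N = 128, 256
acc = torch.randn(BLOCK_M, BLOCK_N)

# Mirrors tl.reshape / tl.permute / tl.split from the kernel epilogue.
tmp = acc.reshape(BLOCK_M, 2, BLOCK_N // 2)
tmp = tmp.permute(0, 2, 1)
acc0, acc1 = tmp[..., 0], tmp[..., 1]

# The two halves are the left and right column blocks of the accumulator,
# matching the stores to output_ptrs0 and output_ptrs1 (offset by BLOCK_N // 2).
assert torch.equal(acc0, acc[:, :BLOCK_N // 2])
assert torch.equal(acc1, acc[:, BLOCK_N // 2:])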
