Commit d0b16ab
[BACKEND] Initial multicta support in gluon (#8587)
We support num-cta all the way up to 16 and simplify the CTA indexing associated with it. We also add a comprehensive test for both Hopper and Blackwell and fix a few latent issues we had when supporting `num_cta > 2`.
1 parent e65ac45 commit d0b16ab

File tree

13 files changed: +180 -187 lines changed

include/triton/Dialect/TritonNvidiaGPU/Transforms/Passes.h

Lines changed: 1 addition & 10 deletions
@@ -30,16 +30,7 @@ namespace mlir {
 namespace triton {
 namespace nvidia_gpu {
 
-// Used by Triton runtime
-struct ClusterInfo {
-  ClusterInfo() : clusterDimX(1), clusterDimY(1), clusterDimZ(1) {}
-  int clusterDimX;
-  int clusterDimY;
-  int clusterDimZ;
-};
-
-std::unique_ptr<Pass> createTritonNvidiaGPUPlanCTAPass(
-    mlir::triton::nvidia_gpu::ClusterInfo *clusterInfo = nullptr);
+std::unique_ptr<Pass> createTritonNvidiaGPUPlanCTAPass();
 
 #define GEN_PASS_DECL
 #include "triton/Dialect/TritonNvidiaGPU/Transforms/Passes.h.inc"
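For context, the deleted ClusterInfo was an out-parameter that let the caller read back the cluster shape chosen by the pass ("Used by Triton runtime"). A minimal sketch of how that hook was used, assuming standard MLIR PassManager boilerplate; the function name and surrounding pipeline code are illustrative, not part of this commit:

#include "mlir/IR/BuiltinOps.h"
#include "mlir/Pass/PassManager.h"
#include "triton/Dialect/TritonNvidiaGPU/Transforms/Passes.h"

// Old side channel (removed by this commit): the caller owned a ClusterInfo,
// handed a pointer to the pass, and read the cluster dims after the run.
void runPlanCTA(mlir::MLIRContext &context, mlir::ModuleOp module) {
  mlir::triton::nvidia_gpu::ClusterInfo info; // defaults to 1x1x1
  mlir::PassManager pm(&context);
  pm.addPass(
      mlir::triton::nvidia_gpu::createTritonNvidiaGPUPlanCTAPass(&info));
  if (mlir::succeeded(pm.run(module))) {
    // The runtime consumed info.clusterDimX/Y/Z here. After this commit the
    // factory takes no arguments and this side channel is gone.
  }
}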

lib/Dialect/TritonNvidiaGPU/IR/Dialect.cpp

Lines changed: 2 additions & 2 deletions
@@ -144,8 +144,8 @@ static std::optional<LinearLayout> getDistributedLayoutForTmemLdSt(
   // Get CTALayout without broadcasting to divide the ll
   // as the TMEM layout does not reflect CTA broadcasting
   auto splitNum = ctaLayout->getCTASplitNum();
-  auto ctaBlockSplit =
-      CTALayoutAttr::get(ctx, splitNum, splitNum, ctaLayout->getCTAOrder());
+  // The cta order in TMEM is always [0, 1]
+  auto ctaBlockSplit = CTALayoutAttr::get(ctx, splitNum, splitNum, {0, 1});
   auto ctaBlockSplitLL = gpu::makeCgaLayout(ctaBlockSplit);
   assert(ctaBlockSplitLL.getNumOutDims() == ll.getNumOutDims());
   // rename block into col
lib/Dialect/TritonNvidiaGPU/IR/TensorMemoryUtils.cpp

Lines changed: 1 addition & 11 deletions
@@ -299,19 +299,9 @@ computeTMemLdStEncodingInfo(RankedTensorType regTy, MemDescType memTy,
   cvt = LinearLayout(bases, cvt.getOutDims(),
                      /*isSurjective=*/cvt.isSurjective());
 
-  // tmemBase already encodes CTA/block offsets so we just remove them from the
-  // cvt
-  auto kBlock = StringAttr::get(ctx, "block");
-  auto kCol = StringAttr::get(ctx, "col");
-  auto nCTAs = cvt.getInDimSize(kBlock);
-  auto maybeQuot =
-      divideRight(cvt, LinearLayout::identity1D(nCTAs, kBlock, kCol));
-  assert(maybeQuot.has_value());
-  auto quot = maybeQuot->unsqueezeIn(kBlock);
-
   bool isScales = isa<TensorMemoryScalesEncodingAttr>(memTy.getEncoding());
   int bitwidth = memTy.getElementTypeBitWidth();
-  return lowerTMemLdSt(quot, maxnreg, bitwidth, isScales, emitError);
+  return lowerTMemLdSt(cvt, maxnreg, bitwidth, isScales, emitError);
 }
 
 } // namespace mlir::triton::nvidia_gpu
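The deleted block used to strip the CTA ("block") component out of the conversion layout before lowering, on the grounds that tmemBase already carries the per-CTA offset; with this change lowerTMemLdSt receives the full layout and deals with the block dimension itself. A toy picture of the separability the old code relied on, using plain arithmetic as a stand-in for Triton's LinearLayout (the linear formula is an assumption for illustration only):

#include <cassert>
#include <cstdio>

// Stand-in: suppose the conversion maps (block, col) to a TMEM column as
//   fullCvt(block, col) = block * colsPerBlock + col,
// i.e. the block input contributes a pure offset. "Dividing out" that
// component leaves quot(col) = col, which was only valid because tmemBase
// already encoded the block * colsPerBlock part.
unsigned fullCvt(unsigned block, unsigned col, unsigned colsPerBlock) {
  return block * colsPerBlock + col;
}
unsigned quot(unsigned col) { return col; }

int main() {
  const unsigned colsPerBlock = 16;
  for (unsigned block = 0; block < 4; ++block)
    for (unsigned col = 0; col < colsPerBlock; ++col)
      assert(fullCvt(block, col, colsPerBlock) ==
             block * colsPerBlock + quot(col));
  std::printf("block component separates out as a pure offset\n");
}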

lib/Dialect/TritonNvidiaGPU/Transforms/PlanCTA.cpp

Lines changed: 15 additions & 67 deletions
@@ -84,8 +84,7 @@ replaceCTALayout(ttg::DistributedEncodingTrait layout,
 
 class CTAPlanner {
 public:
-  CTAPlanner(ClusterInfo *clusterInfo_);
-  ~CTAPlanner();
+  CTAPlanner();
 
   void run(triton::FuncOp &funcOp);
 
@@ -95,7 +94,6 @@ class CTAPlanner {
   bool isBackward(CastOp cast) const;
   bool isForward(CastOp cast) const;
 
-  void setTiling(llvm::ArrayRef<unsigned> CTAsPerCGA);
   bool processDot(triton::FuncOp &funcOp);
   bool processReduce(triton::FuncOp &funcOp);
   void processStoreLikeOps(triton::FuncOp &funcOp);
@@ -149,39 +147,17 @@ class CTAPlanner {
   bool processMultiUsersBackward(Value input, CastOp cast);
   bool processMultiUsersForward(Value output, CastOp cast);
 
-  // This flag indicates whether clusterInfo needs to be deleted in the
-  // destructor of CTAPlanner. The flag `ownInfo` is set to false when a
-  // non-null pointer to clusterInfo is passed to the constructor of CTAPlanner.
-  // Otherwise, a self-managed ClusterInfo will be created and the ownInfo will
-  // be set to true.
-  bool ownInfo;
-  ClusterInfo *clusterInfo;
-  bool tiled;
+  void markTiled();
+
   unsigned step;
   unsigned stepUnchanged;
+  bool tiled;
   std::queue<CastOp> queue;
 };
 
-CTAPlanner::CTAPlanner(ClusterInfo *clusterInfo_)
-    : ownInfo(false), clusterInfo(clusterInfo_), tiled(false), step(0),
-      stepUnchanged(0) {
-  if (clusterInfo == nullptr) {
-    clusterInfo = new ClusterInfo();
-    ownInfo = true;
-  }
-}
-
-CTAPlanner::~CTAPlanner() {
-  if (ownInfo) {
-    delete clusterInfo;
-    // Actually not necessary but safer
-    ownInfo = false;
-    clusterInfo = nullptr;
-  }
-}
+CTAPlanner::CTAPlanner() : step(0), stepUnchanged(0), tiled(false) {}
 
 void CTAPlanner::run(triton::FuncOp &funcOp) {
-  assert(!tiled && "Please create a new CTAPlanner");
   static const unsigned maxSteps = 10000;
 
   auto nextStep = [&]() {
@@ -232,29 +208,9 @@ bool CTAPlanner::isForward(CastOp cast) const {
   return cast->getAttrOfType<StringAttr>("direction") == "forward";
 }
 
-void CTAPlanner::setTiling(llvm::ArrayRef<unsigned> CTAsPerCGA) {
-  assert(!tiled && "CTA tiling is already determinted");
-  assert(clusterInfo && "ClusterInfo pointer is null");
+void CTAPlanner::markTiled() {
+  assert(!tiled && "CTA tiling is already determined");
   tiled = true;
-  unsigned numCTAs = 1;
-  for (unsigned cta : CTAsPerCGA)
-    numCTAs *= cta;
-  if (numCTAs == 2) {
-    // For 2 CTAs always use 2x1x1.
-    // TODO: can we always serialize the CTAs on X dimension?
-    clusterInfo->clusterDimX = 2;
-    return;
-  }
-
-  if (CTAsPerCGA.size() > 0)
-    clusterInfo->clusterDimX = CTAsPerCGA[0];
-  if (CTAsPerCGA.size() > 1)
-    clusterInfo->clusterDimY = CTAsPerCGA[1];
-  if (CTAsPerCGA.size() > 2)
-    clusterInfo->clusterDimZ = CTAsPerCGA[2];
-  for (auto i = 3; i < CTAsPerCGA.size(); ++i)
-    if (CTAsPerCGA[i] != 1)
-      llvm::report_fatal_error("tiling > 3 dims is not implemented");
 }
 
 bool CTAPlanner::processDot(triton::FuncOp &funcOp) {
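For reference, the mapping that the removed setTiling encoded, restated as a standalone function with worked examples (input values hypothetical; after this commit the planner only marks that tiling happened and no longer publishes cluster dims through ClusterInfo):

#include <array>
#include <cstdio>
#include <vector>

// Restatement of the removed setTiling mapping: CTAsPerCGA -> cluster dims,
// with the 2-CTA case always serialized on X (per the removed TODO). Dims
// beyond the third were a fatal error in the old code unless they were 1.
std::array<int, 3> toClusterDims(const std::vector<unsigned> &ctasPerCGA) {
  unsigned numCTAs = 1;
  for (unsigned c : ctasPerCGA)
    numCTAs *= c;
  if (numCTAs == 2)
    return {2, 1, 1}; // e.g. {1, 2} still became 2x1x1
  std::array<int, 3> dims{1, 1, 1};
  for (size_t i = 0; i < ctasPerCGA.size() && i < 3; ++i)
    dims[i] = static_cast<int>(ctasPerCGA[i]);
  return dims;
}

int main() {
  auto a = toClusterDims({4, 2}); // -> 4x2x1
  auto b = toClusterDims({1, 2}); // -> 2x1x1 via the special case
  std::printf("%dx%dx%d %dx%dx%d\n", a[0], a[1], a[2], b[0], b[1], b[2]);
}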
@@ -297,7 +253,7 @@ bool CTAPlanner::processDot(triton::FuncOp &funcOp) {
   unsigned splitM, splitN;
   std::tie(splitM, splitN) = getCTATiling(M, N, K, ttg::getNumCTAs(dLayout));
   // FIXME: Should consider IR with more than one DotOps
-  setTiling({splitM, splitN, 1});
+  markTiled();
 
   OpBuilder builder(dot);
   auto numThreads = ttg::lookupThreadsPerWarp(builder);
@@ -373,7 +329,7 @@ bool CTAPlanner::processReduce(triton::FuncOp &funcOp) {
   auto CTALayout =
       ttg::CTALayoutAttr::get(context, CTAsPerCGA, CTASplitNum, CTAOrder);
   if (!tiled)
-    setTiling(CTALayout.getCTAsPerCGA());
+    markTiled();
   auto newSrcLayout =
       replaceCTALayout(cast<ttg::DistributedEncodingTrait>(srcLayout),
                        srcShape, numWarps, CTALayout);
@@ -389,7 +345,7 @@ bool CTAPlanner::processReduce(triton::FuncOp &funcOp) {
 }
 
 void CTAPlanner::processStoreLikeOps(triton::FuncOp &funcOp) {
-  assert(!tiled && "CTA tiling is already determinted");
+  assert(!tiled && "CTA tiling is already determined");
 
   llvm::SmallVector<Operation *> stores;
   funcOp.walk([&](Operation *op) {
@@ -412,7 +368,7 @@ void CTAPlanner::processStoreLikeOps(triton::FuncOp &funcOp) {
     if (!tiled) {
       // Use CTA tiling of the first store-like op as global CTA tiling
      CTALayout = ttg::getCTALayout(tensorTy.getEncoding());
-      setTiling(CTALayout.getCTAsPerCGA());
+      markTiled();
    }
     auto newLayout = replaceCTALayout(
         cast<ttg::DistributedEncodingTrait>(tensorTy.getEncoding()),
@@ -421,11 +377,8 @@ void CTAPlanner::processStoreLikeOps(triton::FuncOp &funcOp) {
     }
   }
 
-  // If all store-like ops are processing scalar values and no ReduceOp is
-  // found, we can conclude that this is an all-scalar computation, since
-  // ReduceOp is the only op that converts tensor values to scalar values.
   if (!tiled)
-    setTiling({1, 1, 1});
+    markTiled();
 }
 
 bool CTAPlanner::propagate(CastOp cast) {
@@ -1042,8 +995,6 @@ bool CTAPlanner::processMultiUsersForward(Value castResult, CastOp cast) {
 } // anonymous namespace
 
 struct PlanCTAPass : public impl::TritonGPUPlanCTAPassBase<PlanCTAPass> {
-  PlanCTAPass(ClusterInfo *clusterInfo_ = nullptr)
-      : clusterInfo(clusterInfo_) {}
   void runOnOperation() override {
     ModuleOp mod = getOperation();
 
@@ -1052,7 +1003,7 @@ struct PlanCTAPass : public impl::TritonGPUPlanCTAPassBase<PlanCTAPass> {
       return;
 
     mod.walk([&](triton::FuncOp funcOp) {
-      CTAPlanner planner(clusterInfo);
+      CTAPlanner planner;
       planner.run(funcOp);
 
       // FIXME: Clone funcOp so that the IR change can be identified after
@@ -1064,13 +1015,10 @@ struct PlanCTAPass : public impl::TritonGPUPlanCTAPassBase<PlanCTAPass> {
       funcOp.erase();
     });
   }
-
-  ClusterInfo *clusterInfo;
 };
 
-std::unique_ptr<Pass>
-createTritonNvidiaGPUPlanCTAPass(ClusterInfo *clusterInfo) {
-  return std::make_unique<PlanCTAPass>(clusterInfo);
+std::unique_ptr<Pass> createTritonNvidiaGPUPlanCTAPass() {
+  return std::make_unique<PlanCTAPass>();
 }
 
 } // namespace nvidia_gpu
