@@ -84,8 +84,7 @@ replaceCTALayout(ttg::DistributedEncodingTrait layout,
 
 class CTAPlanner {
 public:
-  CTAPlanner(ClusterInfo *clusterInfo_);
-  ~CTAPlanner();
+  CTAPlanner();
 
   void run(triton::FuncOp &funcOp);
 
@@ -95,7 +94,6 @@ class CTAPlanner {
   bool isBackward(CastOp cast) const;
   bool isForward(CastOp cast) const;
 
-  void setTiling(llvm::ArrayRef<unsigned> CTAsPerCGA);
   bool processDot(triton::FuncOp &funcOp);
   bool processReduce(triton::FuncOp &funcOp);
   void processStoreLikeOps(triton::FuncOp &funcOp);
@@ -149,39 +147,17 @@ class CTAPlanner {
   bool processMultiUsersBackward(Value input, CastOp cast);
   bool processMultiUsersForward(Value output, CastOp cast);
 
-  // This flag indicates whether clusterInfo needs to be deleted in the
-  // destructor of CTAPlanner. The flag `ownInfo` is set to false when a
-  // non-null pointer to clusterInfo is passed to the constructor of CTAPlanner.
-  // Otherwise, a self-managed ClusterInfo will be created and the ownInfo will
-  // be set to true.
-  bool ownInfo;
-  ClusterInfo *clusterInfo;
-  bool tiled;
+  void markTiled();
+
   unsigned step;
   unsigned stepUnchanged;
+  bool tiled;
   std::queue<CastOp> queue;
 };
 
-CTAPlanner::CTAPlanner(ClusterInfo *clusterInfo_)
-    : ownInfo(false), clusterInfo(clusterInfo_), tiled(false), step(0),
-      stepUnchanged(0) {
-  if (clusterInfo == nullptr) {
-    clusterInfo = new ClusterInfo();
-    ownInfo = true;
-  }
-}
-
-CTAPlanner::~CTAPlanner() {
-  if (ownInfo) {
-    delete clusterInfo;
-    // Actually not necessary but safer
-    ownInfo = false;
-    clusterInfo = nullptr;
-  }
-}
+CTAPlanner::CTAPlanner() : step(0), stepUnchanged(0), tiled(false) {}
 
 void CTAPlanner::run(triton::FuncOp &funcOp) {
-  assert(!tiled && "Please create a new CTAPlanner");
   static const unsigned maxSteps = 10000;
 
   auto nextStep = [&]() {
@@ -232,29 +208,9 @@ bool CTAPlanner::isForward(CastOp cast) const {
   return cast->getAttrOfType<StringAttr>("direction") == "forward";
 }
 
-void CTAPlanner::setTiling(llvm::ArrayRef<unsigned> CTAsPerCGA) {
-  assert(!tiled && "CTA tiling is already determinted");
-  assert(clusterInfo && "ClusterInfo pointer is null");
+void CTAPlanner::markTiled() {
+  assert(!tiled && "CTA tiling is already determined");
   tiled = true;
-  unsigned numCTAs = 1;
-  for (unsigned cta : CTAsPerCGA)
-    numCTAs *= cta;
-  if (numCTAs == 2) {
-    // For 2 CTAs always use 2x1x1.
-    // TODO: can we always serialize the CTAs on X dimension?
-    clusterInfo->clusterDimX = 2;
-    return;
-  }
-
-  if (CTAsPerCGA.size() > 0)
-    clusterInfo->clusterDimX = CTAsPerCGA[0];
-  if (CTAsPerCGA.size() > 1)
-    clusterInfo->clusterDimY = CTAsPerCGA[1];
-  if (CTAsPerCGA.size() > 2)
-    clusterInfo->clusterDimZ = CTAsPerCGA[2];
-  for (auto i = 3; i < CTAsPerCGA.size(); ++i)
-    if (CTAsPerCGA[i] != 1)
-      llvm::report_fatal_error("tiling > 3 dims is not implemented");
 }
 
 bool CTAPlanner::processDot(triton::FuncOp &funcOp) {
@@ -297,7 +253,7 @@ bool CTAPlanner::processDot(triton::FuncOp &funcOp) {
   unsigned splitM, splitN;
   std::tie(splitM, splitN) = getCTATiling(M, N, K, ttg::getNumCTAs(dLayout));
   // FIXME: Should consider IR with more than one DotOps
-  setTiling({splitM, splitN, 1});
+  markTiled();
 
   OpBuilder builder(dot);
   auto numThreads = ttg::lookupThreadsPerWarp(builder);
@@ -373,7 +329,7 @@ bool CTAPlanner::processReduce(triton::FuncOp &funcOp) {
   auto CTALayout =
       ttg::CTALayoutAttr::get(context, CTAsPerCGA, CTASplitNum, CTAOrder);
   if (!tiled)
-    setTiling(CTALayout.getCTAsPerCGA());
+    markTiled();
   auto newSrcLayout =
       replaceCTALayout(cast<ttg::DistributedEncodingTrait>(srcLayout),
                        srcShape, numWarps, CTALayout);
@@ -389,7 +345,7 @@ bool CTAPlanner::processReduce(triton::FuncOp &funcOp) {
 }
 
 void CTAPlanner::processStoreLikeOps(triton::FuncOp &funcOp) {
-  assert(!tiled && "CTA tiling is already determinted");
+  assert(!tiled && "CTA tiling is already determined");
 
   llvm::SmallVector<Operation *> stores;
   funcOp.walk([&](Operation *op) {
@@ -412,7 +368,7 @@ void CTAPlanner::processStoreLikeOps(triton::FuncOp &funcOp) {
       if (!tiled) {
         // Use CTA tiling of the first store-like op as global CTA tiling
         CTALayout = ttg::getCTALayout(tensorTy.getEncoding());
-        setTiling(CTALayout.getCTAsPerCGA());
+        markTiled();
       }
       auto newLayout = replaceCTALayout(
           cast<ttg::DistributedEncodingTrait>(tensorTy.getEncoding()),
@@ -421,11 +377,8 @@ void CTAPlanner::processStoreLikeOps(triton::FuncOp &funcOp) {
     }
   }
 
-  // If all store-like ops are processing scalar values and no ReduceOp is
-  // found, we can conclude that this is an all-scalar computation, since
-  // ReduceOp is the only op that converts tensor values to scalar values.
   if (!tiled)
-    setTiling({1, 1, 1});
+    markTiled();
 }
 
 bool CTAPlanner::propagate(CastOp cast) {
@@ -1042,8 +995,6 @@ bool CTAPlanner::processMultiUsersForward(Value castResult, CastOp cast) {
 } // anonymous namespace
 
 struct PlanCTAPass : public impl::TritonGPUPlanCTAPassBase<PlanCTAPass> {
-  PlanCTAPass(ClusterInfo *clusterInfo_ = nullptr)
-      : clusterInfo(clusterInfo_) {}
   void runOnOperation() override {
     ModuleOp mod = getOperation();
 
@@ -1052,7 +1003,7 @@ struct PlanCTAPass : public impl::TritonGPUPlanCTAPassBase<PlanCTAPass> {
       return;
 
     mod.walk([&](triton::FuncOp funcOp) {
-      CTAPlanner planner(clusterInfo);
+      CTAPlanner planner;
       planner.run(funcOp);
 
       // FIXME: Clone funcOp so that the IR change can be identified after
@@ -1064,13 +1015,10 @@ struct PlanCTAPass : public impl::TritonGPUPlanCTAPassBase<PlanCTAPass> {
       funcOp.erase();
     });
   }
-
-  ClusterInfo *clusterInfo;
 };
 
-std::unique_ptr<Pass>
-createTritonNvidiaGPUPlanCTAPass(ClusterInfo *clusterInfo) {
-  return std::make_unique<PlanCTAPass>(clusterInfo);
+std::unique_ptr<Pass> createTritonNvidiaGPUPlanCTAPass() {
+  return std::make_unique<PlanCTAPass>();
 }
 
 } // namespace nvidia_gpu
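
For readers integrating this change, a minimal sketch of how the pass can be scheduled after the refactor, assuming a standard mlir::PassManager setup. The helper function name, the include paths, and the exact namespace qualification of createTritonNvidiaGPUPlanCTAPass below are illustrative assumptions, not part of this commit; the point is that no ClusterInfo pointer is threaded through the factory anymore.

#include "mlir/IR/BuiltinOps.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Support/LogicalResult.h"

// Sketch only: after this change the pass factory takes no arguments.
// The namespace below is assumed; adjust it to the actual Triton headers.
static mlir::LogicalResult runPlanCTA(mlir::ModuleOp mod) {
  mlir::PassManager pm(mod.getContext());
  pm.addPass(mlir::triton::nvidia_gpu::createTritonNvidiaGPUPlanCTAPass());
  return pm.run(mod);
}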