@@ -120,15 +120,28 @@ class BlockedToDPAS : public OpRewritePattern<tt::DotOp> {
120120 ttg::intel::DpasEncodingAttr::getOpsPerChannel (elemType);
121121
122122 SmallVector<unsigned > order = {1 , 0 }; // TODO: acceptable default arg?
123- llvm::errs () << " a: " << a << " \n " ;
123+ // llvm::errs() << "a: " << a << "\n";
124124 Operation *aOp = a.getDefiningOp ();
125125 if (aOp) {
126- llvm::errs () << " Processing a op: " << *aOp << " \n " ;
127- assert (aOp->getNumResults () == 1 );
128- auto ret = aOp->getResult (0 );
126+ // llvm::errs() << "Processing a op: " << *aOp << "\n";
127+ Attribute layout;
128+ if (isa<ttg::ConvertLayoutOp>(aOp)) {
129+ // TODO: convertlayoutop converts the order to match dpas, so we need to
130+ // "look through" the conversion. is there a way to prevent the
131+ // conversion in the first place?
132+ assert (aOp->getNumOperands () == 1 );
133+ layout =
134+ cast<RankedTensorType>(aOp->getOperand (0 ).getType ()).getEncoding ();
135+ } else {
136+ assert (aOp->getNumResults () == 1 );
137+ layout =
138+ cast<RankedTensorType>(aOp->getResult (0 ).getType ()).getEncoding ();
139+ }
140+ order = triton::gpu::getOrder (layout);
129141 } else {
130- llvm::errs () << " no A op for A: " << a << " \n " ;
142+ // llvm::errs() << "no A op for A: " << a << "\n";
131143 }
144+ // llvm::errs() << "order: " << order[0] << ", " << order[1] << "\n";
132145
133146 SmallVector<unsigned > warpsPerTile =
134147 getWarpsPerTile (dotOp, dpasCap, retShape, numWarps, order);
0 commit comments