@@ -31,7 +31,8 @@ namespace {
3131SmallVector<unsigned >
3232getWarpsPerTile (tt::DotOp dotOp,
3333 ttg::intel::DpasEncodingAttr::DPASCapability dpasCap,
34- const ArrayRef<int64_t > shape, unsigned numWarps) {
34+ const ArrayRef<int64_t > shape, unsigned numWarps, const SmallVector<unsigned >& order) {
35+
3536 auto filter = [&dotOp](Operation *op) {
3637 return op->getParentRegion () == dotOp->getParentRegion ();
3738 };
@@ -63,7 +64,7 @@ getWarpsPerTile(tt::DotOp dotOp,
6364 uint32_t colRowRatio =
6465 ceil<uint32_t >(dpasCap.executionSize , dpasCap.repeatCount );
6566
66- int rowDim = rank - 2 , colDim = rank - 1 ;
67+ int rowDim = order[ rank - 2 ] , colDim = order[ rank - 1 ] ;
6768 do {
6869 if (ret[rowDim] * ret[colDim] >= numWarps)
6970 break ;
@@ -122,31 +123,23 @@ class BlockedToDPAS : public OpRewritePattern<tt::DotOp> {
122123 oldAType.getElementType ().isFloat8E4M3FN ())
123124 dpasElemBitWidths = 2 * dpasElemBitWidths;
124125
125- // now we can get the order from the a defining op
126-
127- llvm::errs () << " oldAType: " << oldAType << " \n " ;
128- llvm::errs () << " oldBType: " << oldBType << " \n " ;
129-
130- llvm::errs () << " a: " << a << " \n " ;
131- llvm::errs () << " a defining op: " << *a.getDefiningOp () << " \n " ;
132-
133126 SmallVector<unsigned > order;
134- Operation* aOp = a.getDefiningOp ();
127+ Operation * aOp = a.getDefiningOp ();
135128 if (isa<ttg::ConvertLayoutOp>(aOp)) {
136129 assert (aOp->getNumOperands () == 1 );
137130 auto aLoad = aOp->getOperand (0 );
138- order = triton::gpu::getOrder (cast<RankedTensorType>(aLoad.getType ()).getEncoding ());
131+ order = triton::gpu::getOrder (
132+ cast<RankedTensorType>(aLoad.getType ()).getEncoding ());
139133 } else {
140134 assert (isa<tt::LoadOp>(aOp) && " expecting load input to DPAS" );
141- order = triton::gpu::getOrder (cast<RankedTensorType>(aLoad.getType ()).getEncoding ());
135+ assert (aOp->getNumResults () == 1 );
136+ auto ret = aOp->getResult (0 );
137+ order = triton::gpu::getOrder (
138+ cast<RankedTensorType>(ret.getType ()).getEncoding ());
142139 }
143- // order = triton::gpu::getOrder(a.getDefiningOp().getEncoding());
144- llvm::errs () << " a load order: " << order[0 ] << " , " << order[1 ] << " \n " ;
145-
146- // now find the fast changing dimension from the order
147140
148141 SmallVector<unsigned > warpsPerTile =
149- getWarpsPerTile (dotOp, dpasCap, retShape, numWarps);
142+ getWarpsPerTile (dotOp, dpasCap, retShape, numWarps, order );
150143 size_t rank = retShape.size ();
151144 SmallVector<unsigned > repCluster (rank, 1 );
152145
0 commit comments