@@ -116,6 +116,35 @@ class BlockedToDPAS : public OpRewritePattern<tt::DotOp> {
116116 Type elemType = oldAType.getElementType ();
117117 unsigned opsPerChan =
118118 ttg::intel::DpasEncodingAttr::getOpsPerChannel (elemType);
119+
120+ // We are upcasting FP8 to FP16
121+ if (oldAType.getElementType ().isFloat8E5M2 () ||
122+ oldAType.getElementType ().isFloat8E4M3FN ())
123+ dpasElemBitWidths = 2 * dpasElemBitWidths;
124+
125+ // now we can get the order from the a defining op
126+
127+ llvm::errs () << " oldAType: " << oldAType << " \n " ;
128+ llvm::errs () << " oldBType: " << oldBType << " \n " ;
129+
130+ llvm::errs () << " a: " << a << " \n " ;
131+ llvm::errs () << " a defining op: " << *a.getDefiningOp () << " \n " ;
132+
133+ SmallVector<unsigned > order;
134+ Operation* aOp = a.getDefiningOp ();
135+ if (isa<ttg::ConvertLayoutOp>(aOp)) {
136+ assert (aOp->getNumOperands () == 1 );
137+ auto aLoad = aOp->getOperand (0 );
138+ order = triton::gpu::getOrder (cast<RankedTensorType>(aLoad.getType ()).getEncoding ());
139+ } else {
140+ assert (isa<tt::LoadOp>(aOp) && " expecting load input to DPAS" );
141+ order = triton::gpu::getOrder (cast<RankedTensorType>(aLoad.getType ()).getEncoding ());
142+ }
143+ // order = triton::gpu::getOrder(a.getDefiningOp().getEncoding());
144+ llvm::errs () << " a load order: " << order[0 ] << " , " << order[1 ] << " \n " ;
145+
146+ // now find the fast changing dimension from the order
147+
119148 SmallVector<unsigned > warpsPerTile =
120149 getWarpsPerTile (dotOp, dpasCap, retShape, numWarps);
121150 size_t rank = retShape.size ();
0 commit comments