@@ -31,7 +31,8 @@ namespace {
3131SmallVector<unsigned >
3232getWarpsPerTile (tt::DotOp dotOp,
3333 ttg::intel::DpasEncodingAttr::DPASCapability dpasCap,
34- const ArrayRef<int64_t > shape, unsigned numWarps, const SmallVector<unsigned >& order) {
34+ const ArrayRef<int64_t > shape, unsigned numWarps,
35+ const SmallVector<unsigned > &order) {
3536
3637 auto filter = [&dotOp](Operation *op) {
3738 return op->getParentRegion () == dotOp->getParentRegion ();
@@ -63,6 +64,7 @@ getWarpsPerTile(tt::DotOp dotOp,
6364 ceil<uint32_t >(dpasCap.repeatCount , dpasCap.executionSize );
6465 uint32_t colRowRatio =
6566 ceil<uint32_t >(dpasCap.executionSize , dpasCap.repeatCount );
67+ llvm::errs () << " rowColRation: " << rowColRatio << " , colRowRatio: " << colRowRatio << " , ret: " << ret[0 ] << " , " << ret[1 ] << " \n " ;
6668
6769 int rowDim = order[rank - 2 ], colDim = order[rank - 1 ];
6870 do {
@@ -118,25 +120,29 @@ class BlockedToDPAS : public OpRewritePattern<tt::DotOp> {
118120 unsigned opsPerChan =
119121 ttg::intel::DpasEncodingAttr::getOpsPerChannel (elemType);
120122
121- // We are upcasting FP8 to FP16
122- if (oldAType.getElementType ().isFloat8E5M2 () ||
123- oldAType.getElementType ().isFloat8E4M3FN ())
124- dpasElemBitWidths = 2 * dpasElemBitWidths;
125-
126- SmallVector<unsigned > order;
123+ SmallVector<unsigned > order = {1 , 0 }; // TODO: acceptable default arg?
124+ // llvm::errs() << "a: " << a << "\n";
127125 Operation *aOp = a.getDefiningOp ();
128- if (isa<ttg::ConvertLayoutOp>(aOp)) {
129- assert (aOp->getNumOperands () == 1 );
130- auto aLoad = aOp->getOperand (0 );
131- order = triton::gpu::getOrder (
132- cast<RankedTensorType>(aLoad.getType ()).getEncoding ());
126+ if (aOp) {
127+ // llvm::errs() << "Processing a op: " << *aOp << "\n";
128+ Attribute layout;
129+ if (isa<ttg::ConvertLayoutOp>(aOp)) {
130+ // TODO: convertlayoutop converts the order to match dpas, so we need to
131+ // "look through" the conversion. is there a way to prevent the
132+ // conversion in the first place?
133+ assert (aOp->getNumOperands () == 1 );
134+ layout =
135+ cast<RankedTensorType>(aOp->getOperand (0 ).getType ()).getEncoding ();
136+ } else {
137+ assert (aOp->getNumResults () == 1 );
138+ layout =
139+ cast<RankedTensorType>(aOp->getResult (0 ).getType ()).getEncoding ();
140+ }
141+ order = triton::gpu::getOrder (layout);
133142 } else {
134- assert (isa<tt::LoadOp>(aOp) && " expecting load input to DPAS" );
135- assert (aOp->getNumResults () == 1 );
136- auto ret = aOp->getResult (0 );
137- order = triton::gpu::getOrder (
138- cast<RankedTensorType>(ret.getType ()).getEncoding ());
143+ // llvm::errs() << "no A op for A: " << a << "\n";
139144 }
145+ llvm::errs () << " order: " << order[0 ] << " , " << order[1 ] << " \n " ;
140146
141147 SmallVector<unsigned > warpsPerTile =
142148 getWarpsPerTile (dotOp, dpasCap, retShape, numWarps, order);
0 commit comments