@@ -64,7 +64,6 @@ getWarpsPerTile(tt::DotOp dotOp,
6464 ceil<uint32_t >(dpasCap.repeatCount , dpasCap.executionSize );
6565 uint32_t colRowRatio =
6666 ceil<uint32_t >(dpasCap.executionSize , dpasCap.repeatCount );
67- llvm::errs () << " rowColRation: " << rowColRatio << " , colRowRatio: " << colRowRatio << " , ret: " << ret[0 ] << " , " << ret[1 ] << " \n " ;
6867
6968 int rowDim = order[rank - 2 ], colDim = order[rank - 1 ];
7069 do {
@@ -80,7 +79,6 @@ getWarpsPerTile(tt::DotOp dotOp,
8079 ret[colDim] *= 2 ;
8180 }
8281 } while (true );
83-
8482 return ret;
8583}
8684
@@ -120,29 +118,18 @@ class BlockedToDPAS : public OpRewritePattern<tt::DotOp> {
120118 unsigned opsPerChan =
121119 ttg::intel::DpasEncodingAttr::getOpsPerChannel (elemType);
122120
123- SmallVector<unsigned > order = {1 , 0 }; // TODO: acceptable default arg?
124- // llvm::errs() << "a: " << a << "\n";
121+ SmallVector<unsigned > order = {0 , 1 };
125122 Operation *aOp = a.getDefiningOp ();
126- if (aOp) {
127- // llvm::errs() << "Processing a op: " << *aOp << "\n";
128- Attribute layout;
129- if (isa<ttg::ConvertLayoutOp>(aOp)) {
130- // TODO: convertlayoutop converts the order to match dpas, so we need to
131- // "look through" the conversion. is there a way to prevent the
132- // conversion in the first place?
133- assert (aOp->getNumOperands () == 1 );
134- layout =
135- cast<RankedTensorType>(aOp->getOperand (0 ).getType ()).getEncoding ();
136- } else {
137- assert (aOp->getNumResults () == 1 );
138- layout =
139- cast<RankedTensorType>(aOp->getResult (0 ).getType ()).getEncoding ();
140- }
123+ if (aOp && isa<ttg::ConvertLayoutOp>(aOp)) {
124+ auto valueToConvert = aOp->getOperand (0 );
125+ aOp = valueToConvert.getDefiningOp ();
126+ }
127+ if (aOp && isa<tt::LoadOp>(aOp)) {
128+ assert (aOp->getNumResults () == 1 );
129+ Attribute layout =
130+ cast<RankedTensorType>(aOp->getResult (0 ).getType ()).getEncoding ();
141131 order = triton::gpu::getOrder (layout);
142- } else {
143- // llvm::errs() << "no A op for A: " << a << "\n";
144132 }
145- llvm::errs () << " order: " << order[0 ] << " , " << order[1 ] << " \n " ;
146133
147134 SmallVector<unsigned > warpsPerTile =
148135 getWarpsPerTile (dotOp, dpasCap, retShape, numWarps, order);
0 commit comments