@@ -80,7 +80,6 @@ getWarpsPerTile(tt::DotOp dotOp,
8080 ret[colDim] *= 2 ;
8181 }
8282 } while (true );
83-
8483 return ret;
8584}
8685
@@ -120,27 +119,18 @@ class BlockedToDPAS : public OpRewritePattern<tt::DotOp> {
120119 unsigned opsPerChan =
121120 ttg::intel::DpasEncodingAttr::getOpsPerChannel (elemType);
122121
123- SmallVector<unsigned > order = {1 , 0 }; // TODO: acceptable default arg?
124- // llvm::errs() << "a: " << a << "\n";
122+ SmallVector<unsigned > order = {0 , 1 };
125123 Operation *aOp = a.getDefiningOp ();
126- if (aOp) {
127- // llvm::errs() << "Processing a op: " << *aOp << "\n";
124+ if (isa<ttg::ConvertLayoutOp>(aOp)) {
125+ auto valueToConvert = aOp->getOperand (0 );
126+ aOp = valueToConvert.getDefiningOp ();
127+ }
128+ if (aOp && isa<tt::LoadOp>(aOp)) {
128129 Attribute layout;
129- if (isa<ttg::ConvertLayoutOp>(aOp)) {
130- // TODO: convertlayoutop converts the order to match dpas, so we need to
131- // "look through" the conversion. is there a way to prevent the
132- // conversion in the first place?
133- assert (aOp->getNumOperands () == 1 );
134- layout =
135- cast<RankedTensorType>(aOp->getOperand (0 ).getType ()).getEncoding ();
136- } else {
137130 assert (aOp->getNumResults () == 1 );
138131 layout =
139132 cast<RankedTensorType>(aOp->getResult (0 ).getType ()).getEncoding ();
140- }
141133 order = triton::gpu::getOrder (layout);
142- } else {
143- // llvm::errs() << "no A op for A: " << a << "\n";
144134 }
145135 llvm::errs () << " order: " << order[0 ] << " , " << order[1 ] << " \n " ;
146136
0 commit comments