Skip to content

Commit ce3f3c9

Browse files
committed
fixups again after rebase
1 parent 12afc8d commit ce3f3c9

File tree

1 file changed

+18
-5
lines changed

1 file changed

+18
-5
lines changed

third_party/intel/lib/TritonIntelGPUTransforms/AccelerateMatmul.cpp

Lines changed: 18 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -120,15 +120,28 @@ class BlockedToDPAS : public OpRewritePattern<tt::DotOp> {
120120
ttg::intel::DpasEncodingAttr::getOpsPerChannel(elemType);
121121

122122
SmallVector<unsigned> order = {1, 0}; // TODO: acceptable default arg?
123-
llvm::errs() << "a: " << a << "\n";
123+
// llvm::errs() << "a: " << a << "\n";
124124
Operation *aOp = a.getDefiningOp();
125125
if (aOp) {
126-
llvm::errs() << "Processing a op: " << *aOp << "\n";
127-
assert(aOp->getNumResults() == 1);
128-
auto ret = aOp->getResult(0);
126+
// llvm::errs() << "Processing a op: " << *aOp << "\n";
127+
Attribute layout;
128+
if (isa<ttg::ConvertLayoutOp>(aOp)) {
129+
// TODO: ConvertLayoutOp converts the order to match DPAS, so we need to
130+
// "look through" the conversion. is there a way to prevent the
131+
// conversion in the first place?
132+
assert(aOp->getNumOperands() == 1);
133+
layout =
134+
cast<RankedTensorType>(aOp->getOperand(0).getType()).getEncoding();
135+
} else {
136+
assert(aOp->getNumResults() == 1);
137+
layout =
138+
cast<RankedTensorType>(aOp->getResult(0).getType()).getEncoding();
139+
}
140+
order = triton::gpu::getOrder(layout);
129141
} else {
130-
llvm::errs() << "no A op for A: " << a << "\n";
142+
// llvm::errs() << "no A op for A: " << a << "\n";
131143
}
144+
// llvm::errs() << "order: " << order[0] << ", " << order[1] << "\n";
132145

133146
SmallVector<unsigned> warpsPerTile =
134147
getWarpsPerTile(dotOp, dpasCap, retShape, numWarps, order);

0 commit comments

Comments
 (0)