Skip to content

Commit 976c1a1

Browse files
committed
fixup default order and more restrictive selection
format + remove debug prints fixup
1 parent 0f7eb06 commit 976c1a1

File tree

1 file changed

+9
-22
lines changed

1 file changed

+9
-22
lines changed

third_party/intel/lib/TritonIntelGPUTransforms/AccelerateMatmul.cpp

Lines changed: 9 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,6 @@ getWarpsPerTile(tt::DotOp dotOp,
6464
ceil<uint32_t>(dpasCap.repeatCount, dpasCap.executionSize);
6565
uint32_t colRowRatio =
6666
ceil<uint32_t>(dpasCap.executionSize, dpasCap.repeatCount);
67-
llvm::errs() << "rowColRation: " << rowColRatio << ", colRowRatio: " << colRowRatio << ", ret: " << ret[0] << ", " << ret[1] << "\n";
6867

6968
int rowDim = order[rank - 2], colDim = order[rank - 1];
7069
do {
@@ -80,7 +79,6 @@ getWarpsPerTile(tt::DotOp dotOp,
8079
ret[colDim] *= 2;
8180
}
8281
} while (true);
83-
8482
return ret;
8583
}
8684

@@ -120,29 +118,18 @@ class BlockedToDPAS : public OpRewritePattern<tt::DotOp> {
120118
unsigned opsPerChan =
121119
ttg::intel::DpasEncodingAttr::getOpsPerChannel(elemType);
122120

123-
SmallVector<unsigned> order = {1, 0}; // TODO: acceptable default arg?
124-
// llvm::errs() << "a: " << a << "\n";
121+
SmallVector<unsigned> order = {0, 1};
125122
Operation *aOp = a.getDefiningOp();
126-
if (aOp) {
127-
// llvm::errs() << "Processing a op: " << *aOp << "\n";
128-
Attribute layout;
129-
if (isa<ttg::ConvertLayoutOp>(aOp)) {
130-
// TODO: convertlayoutop converts the order to match dpas, so we need to
131-
// "look through" the conversion. is there a way to prevent the
132-
// conversion in the first place?
133-
assert(aOp->getNumOperands() == 1);
134-
layout =
135-
cast<RankedTensorType>(aOp->getOperand(0).getType()).getEncoding();
136-
} else {
137-
assert(aOp->getNumResults() == 1);
138-
layout =
139-
cast<RankedTensorType>(aOp->getResult(0).getType()).getEncoding();
140-
}
123+
if (aOp && isa<ttg::ConvertLayoutOp>(aOp)) {
124+
auto valueToConvert = aOp->getOperand(0);
125+
aOp = valueToConvert.getDefiningOp();
126+
}
127+
if (aOp && isa<tt::LoadOp>(aOp)) {
128+
assert(aOp->getNumResults() == 1);
129+
Attribute layout =
130+
cast<RankedTensorType>(aOp->getResult(0).getType()).getEncoding();
141131
order = triton::gpu::getOrder(layout);
142-
} else {
143-
// llvm::errs() << "no A op for A: " << a << "\n";
144132
}
145-
llvm::errs() << "order: " << order[0] << ", " << order[1] << "\n";
146133

147134
SmallVector<unsigned> warpsPerTile =
148135
getWarpsPerTile(dotOp, dpasCap, retShape, numWarps, order);

0 commit comments

Comments
 (0)