Skip to content

Commit 0f7eb06

Browse files
committed
pick a default order to avoid problems getting the order from operations;
remove debug info; format fixup after rebase; fixups again after rebase; more debug logging
1 parent 42e73ab commit 0f7eb06

File tree

1 file changed

+23
-17
lines changed

1 file changed

+23
-17
lines changed

third_party/intel/lib/TritonIntelGPUTransforms/AccelerateMatmul.cpp

Lines changed: 23 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,8 @@ namespace {
3131
SmallVector<unsigned>
3232
getWarpsPerTile(tt::DotOp dotOp,
3333
ttg::intel::DpasEncodingAttr::DPASCapability dpasCap,
34-
const ArrayRef<int64_t> shape, unsigned numWarps, const SmallVector<unsigned>& order) {
34+
const ArrayRef<int64_t> shape, unsigned numWarps,
35+
const SmallVector<unsigned> &order) {
3536

3637
auto filter = [&dotOp](Operation *op) {
3738
return op->getParentRegion() == dotOp->getParentRegion();
@@ -63,6 +64,7 @@ getWarpsPerTile(tt::DotOp dotOp,
6364
ceil<uint32_t>(dpasCap.repeatCount, dpasCap.executionSize);
6465
uint32_t colRowRatio =
6566
ceil<uint32_t>(dpasCap.executionSize, dpasCap.repeatCount);
67+
llvm::errs() << "rowColRatio: " << rowColRatio << ", colRowRatio: " << colRowRatio << ", ret: " << ret[0] << ", " << ret[1] << "\n";
6668

6769
int rowDim = order[rank - 2], colDim = order[rank - 1];
6870
do {
@@ -118,25 +120,29 @@ class BlockedToDPAS : public OpRewritePattern<tt::DotOp> {
118120
unsigned opsPerChan =
119121
ttg::intel::DpasEncodingAttr::getOpsPerChannel(elemType);
120122

121-
// We are upcasting FP8 to FP16
122-
if (oldAType.getElementType().isFloat8E5M2() ||
123-
oldAType.getElementType().isFloat8E4M3FN())
124-
dpasElemBitWidths = 2 * dpasElemBitWidths;
125-
126-
SmallVector<unsigned> order;
123+
SmallVector<unsigned> order = {1, 0}; // TODO: acceptable default arg?
124+
// llvm::errs() << "a: " << a << "\n";
127125
Operation *aOp = a.getDefiningOp();
128-
if (isa<ttg::ConvertLayoutOp>(aOp)) {
129-
assert(aOp->getNumOperands() == 1);
130-
auto aLoad = aOp->getOperand(0);
131-
order = triton::gpu::getOrder(
132-
cast<RankedTensorType>(aLoad.getType()).getEncoding());
126+
if (aOp) {
127+
// llvm::errs() << "Processing a op: " << *aOp << "\n";
128+
Attribute layout;
129+
if (isa<ttg::ConvertLayoutOp>(aOp)) {
130+
// TODO: convertlayoutop converts the order to match dpas, so we need to
131+
// "look through" the conversion. is there a way to prevent the
132+
// conversion in the first place?
133+
assert(aOp->getNumOperands() == 1);
134+
layout =
135+
cast<RankedTensorType>(aOp->getOperand(0).getType()).getEncoding();
136+
} else {
137+
assert(aOp->getNumResults() == 1);
138+
layout =
139+
cast<RankedTensorType>(aOp->getResult(0).getType()).getEncoding();
140+
}
141+
order = triton::gpu::getOrder(layout);
133142
} else {
134-
assert(isa<tt::LoadOp>(aOp) && "expecting load input to DPAS");
135-
assert(aOp->getNumResults() == 1);
136-
auto ret = aOp->getResult(0);
137-
order = triton::gpu::getOrder(
138-
cast<RankedTensorType>(ret.getType()).getEncoding());
143+
// llvm::errs() << "no A op for A: " << a << "\n";
139144
}
145+
llvm::errs() << "order: " << order[0] << ", " << order[1] << "\n";
140146

141147
SmallVector<unsigned> warpsPerTile =
142148
getWarpsPerTile(dotOp, dpasCap, retShape, numWarps, order);

0 commit comments

Comments
 (0)