Skip to content

Commit 7eee7e3

Browse files
committed
[DPAS] Pick warpsPerCTA based on fast changing axis of A matrix
1 parent c4201fa commit 7eee7e3

File tree

1 file changed: 19 additions and 4 deletions (+19 −4 lines changed)

third_party/intel/lib/TritonIntelGPUTransforms/AccelerateMatmul.cpp

Lines changed: 19 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -32,7 +32,8 @@ namespace {
3232

3333
SmallVector<unsigned>
3434
getWarpsPerTile(tt::DotOp dotOp, ttgi::DpasEncodingAttr::DPASCapability dpasCap,
35-
const ArrayRef<int64_t> shape, unsigned numWarps) {
35+
const ArrayRef<int64_t> shape, unsigned numWarps,
36+
const SmallVector<unsigned> &order) {
3637
auto filter = [&dotOp](Operation *op) {
3738
return op->getParentRegion() == dotOp->getParentRegion();
3839
};
@@ -64,7 +65,7 @@ getWarpsPerTile(tt::DotOp dotOp, ttgi::DpasEncodingAttr::DPASCapability dpasCap,
6465
uint32_t colRowRatio =
6566
ceil<uint32_t>(dpasCap.executionSize, dpasCap.repeatCount);
6667

67-
int rowDim = rank - 2, colDim = rank - 1;
68+
int rowDim = order[rank - 2], colDim = order[rank - 1];
6869
do {
6970
if (ret[rowDim] * ret[colDim] >= numWarps)
7071
break;
@@ -78,7 +79,6 @@ getWarpsPerTile(tt::DotOp dotOp, ttgi::DpasEncodingAttr::DPASCapability dpasCap,
7879
ret[colDim] *= 2;
7980
}
8081
} while (true);
81-
8282
return ret;
8383
}
8484

@@ -115,9 +115,24 @@ class BlockedToDPAS : public OpRewritePattern<tt::DotOp> {
115115

116116
auto dpasCap = ttgi::DpasEncodingAttr::getDPASCapability(mod);
117117
Type elemType = oldAType.getElementType();
118+
118119
unsigned opsPerChan = ttgi::DpasEncodingAttr::getOpsPerChannel(elemType);
120+
121+
SmallVector<unsigned> order = {0, 1};
122+
Operation *aOp = a.getDefiningOp();
123+
if (aOp && isa<ttg::ConvertLayoutOp>(aOp)) {
124+
auto valueToConvert = aOp->getOperand(0);
125+
aOp = valueToConvert.getDefiningOp();
126+
}
127+
if (aOp && isa<tt::LoadOp>(aOp)) {
128+
assert(aOp->getNumResults() == 1);
129+
Attribute layout =
130+
cast<RankedTensorType>(aOp->getResult(0).getType()).getEncoding();
131+
order = triton::gpu::getOrder(layout);
132+
}
133+
119134
SmallVector<unsigned> warpsPerTile =
120-
getWarpsPerTile(dotOp, dpasCap, retShape, numWarps);
135+
getWarpsPerTile(dotOp, dpasCap, retShape, numWarps, order);
121136
size_t rank = retShape.size();
122137
SmallVector<unsigned> repCluster(rank, 1);
123138

0 commit comments

Comments (0)