Skip to content

Commit 42e73ab

Browse files
committed
Use the A matrix layout order when determining the DPAS order in AccelerateMatmul
Format the code and remove debug prints
1 parent bb7ce73 commit 42e73ab

File tree

2 files changed: 11 additions and 19 deletions

2 files changed

+11
-19
lines changed

third_party/intel/backend/compiler.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -250,7 +250,6 @@ def make_ttgir(mod, metadata, opt, properties):
250250
intel.passes.ttgpuir.add_rewrite_tensor_pointer(pm)
251251
intel.passes.ttgpuir.add_pipeline(pm, opt.num_stages, False)
252252

253-
254253
passes.ttgpuir.add_optimize_thread_locality(pm)
255254
passes.ttgpuir.add_optimize_dot_operands(pm, True)
256255
passes.common.add_cse(pm)

third_party/intel/lib/TritonIntelGPUTransforms/AccelerateMatmul.cpp

Lines changed: 11 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,8 @@ namespace {
3131
SmallVector<unsigned>
3232
getWarpsPerTile(tt::DotOp dotOp,
3333
ttg::intel::DpasEncodingAttr::DPASCapability dpasCap,
34-
const ArrayRef<int64_t> shape, unsigned numWarps) {
34+
const ArrayRef<int64_t> shape, unsigned numWarps, const SmallVector<unsigned>& order) {
35+
3536
auto filter = [&dotOp](Operation *op) {
3637
return op->getParentRegion() == dotOp->getParentRegion();
3738
};
@@ -63,7 +64,7 @@ getWarpsPerTile(tt::DotOp dotOp,
6364
uint32_t colRowRatio =
6465
ceil<uint32_t>(dpasCap.executionSize, dpasCap.repeatCount);
6566

66-
int rowDim = rank - 2, colDim = rank - 1;
67+
int rowDim = order[rank - 2], colDim = order[rank - 1];
6768
do {
6869
if (ret[rowDim] * ret[colDim] >= numWarps)
6970
break;
@@ -122,31 +123,23 @@ class BlockedToDPAS : public OpRewritePattern<tt::DotOp> {
122123
oldAType.getElementType().isFloat8E4M3FN())
123124
dpasElemBitWidths = 2 * dpasElemBitWidths;
124125

125-
// now we can get the order from the a defining op
126-
127-
llvm::errs() << "oldAType: " << oldAType << "\n";
128-
llvm::errs() << "oldBType: " << oldBType << "\n";
129-
130-
llvm::errs() << "a: " << a << "\n";
131-
llvm::errs() << "a defining op: " << *a.getDefiningOp() << "\n";
132-
133126
SmallVector<unsigned> order;
134-
Operation* aOp = a.getDefiningOp();
127+
Operation *aOp = a.getDefiningOp();
135128
if (isa<ttg::ConvertLayoutOp>(aOp)) {
136129
assert(aOp->getNumOperands() == 1);
137130
auto aLoad = aOp->getOperand(0);
138-
order = triton::gpu::getOrder(cast<RankedTensorType>(aLoad.getType()).getEncoding());
131+
order = triton::gpu::getOrder(
132+
cast<RankedTensorType>(aLoad.getType()).getEncoding());
139133
} else {
140134
assert(isa<tt::LoadOp>(aOp) && "expecting load input to DPAS");
141-
order = triton::gpu::getOrder(cast<RankedTensorType>(aLoad.getType()).getEncoding());
135+
assert(aOp->getNumResults() == 1);
136+
auto ret = aOp->getResult(0);
137+
order = triton::gpu::getOrder(
138+
cast<RankedTensorType>(ret.getType()).getEncoding());
142139
}
143-
// order = triton::gpu::getOrder(a.getDefiningOp().getEncoding());
144-
llvm::errs() << "a load order: " << order[0] << ", " << order[1] << "\n";
145-
146-
// now find the fast changing dimension from the order
147140

148141
SmallVector<unsigned> warpsPerTile =
149-
getWarpsPerTile(dotOp, dpasCap, retShape, numWarps);
142+
getWarpsPerTile(dotOp, dpasCap, retShape, numWarps, order);
150143
size_t rank = retShape.size();
151144
SmallVector<unsigned> repCluster(rank, 1);
152145

0 commit comments

Comments
 (0)