Skip to content

Commit bb7ce73

Browse files
committed
select warps per cta based on fast changing dim of A matrix 1/?
1 parent 44b46e5 commit bb7ce73

File tree

1 file changed

+29
-0
lines changed

1 file changed

+29
-0
lines changed

third_party/intel/lib/TritonIntelGPUTransforms/AccelerateMatmul.cpp

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -116,6 +116,35 @@ class BlockedToDPAS : public OpRewritePattern<tt::DotOp> {
116116
Type elemType = oldAType.getElementType();
117117
unsigned opsPerChan =
118118
ttg::intel::DpasEncodingAttr::getOpsPerChannel(elemType);
119+
120+
// We are upcasting FP8 to FP16
121+
if (oldAType.getElementType().isFloat8E5M2() ||
122+
oldAType.getElementType().isFloat8E4M3FN())
123+
dpasElemBitWidths = 2 * dpasElemBitWidths;
124+
125+
// now we can get the order from the a defining op
126+
127+
llvm::errs() << "oldAType: " << oldAType << "\n";
128+
llvm::errs() << "oldBType: " << oldBType << "\n";
129+
130+
llvm::errs() << "a: " << a << "\n";
131+
llvm::errs() << "a defining op: " << *a.getDefiningOp() << "\n";
132+
133+
SmallVector<unsigned> order;
134+
Operation* aOp = a.getDefiningOp();
135+
if (isa<ttg::ConvertLayoutOp>(aOp)) {
136+
assert(aOp->getNumOperands() == 1);
137+
auto aLoad = aOp->getOperand(0);
138+
order = triton::gpu::getOrder(cast<RankedTensorType>(aLoad.getType()).getEncoding());
139+
} else {
140+
assert(isa<tt::LoadOp>(aOp) && "expecting load input to DPAS");
141+
order = triton::gpu::getOrder(cast<RankedTensorType>(aLoad.getType()).getEncoding());
142+
}
143+
// order = triton::gpu::getOrder(a.getDefiningOp().getEncoding());
144+
llvm::errs() << "a load order: " << order[0] << ", " << order[1] << "\n";
145+
146+
// now find the fast changing dimension from the order
147+
119148
SmallVector<unsigned> warpsPerTile =
120149
getWarpsPerTile(dotOp, dpasCap, retShape, numWarps);
121150
size_t rank = retShape.size();

0 commit comments

Comments
 (0)