Skip to content

Commit b42d979

Browse files
committed
pick a default order to avoid problems getting the order from operations
1 parent dd1bb11 commit b42d979

File tree

1 file changed

+63
-10
lines changed

1 file changed

+63
-10
lines changed

third_party/intel/lib/TritonIntelGPUTransforms/AccelerateMatmul.cpp

Lines changed: 63 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -123,19 +123,72 @@ class BlockedToDPAS : public OpRewritePattern<tt::DotOp> {
123123
oldAType.getElementType().isFloat8E4M3FN())
124124
dpasElemBitWidths = 2 * dpasElemBitWidths;
125125

126-
SmallVector<unsigned> order;
127-
Operation *aOp = a.getDefiningOp();
128-
if (isa<ttg::ConvertLayoutOp>(aOp)) {
129-
assert(aOp->getNumOperands() == 1);
130-
auto aLoad = aOp->getOperand(0);
131-
order = triton::gpu::getOrder(
132-
cast<RankedTensorType>(aLoad.getType()).getEncoding());
126+
SmallVector<unsigned> order = {1, 0}; // TODO: acceptable default arg?
127+
llvm::errs() << "a: " << a << "\n";
128+
Operation* aOp = a.getDefiningOp();
129+
if (aOp) {
130+
llvm::errs() << "Processing a op: " << *aOp << "\n";
131+
#if 0
132+
Operation *aOp{nullptr};
133+
if (auto arg = dyn_cast<BlockArgument>(a)) {
134+
unsigned argNum = arg.getArgNumber();
135+
Operation *argOwner = a.getParentBlock()->getParentOp();
136+
137+
if (auto forOp = dyn_cast<scf::ForOp>(argOwner)) {
138+
auto operand = forOp.getOperand(argNum + forOp.getNumControlOperands() - 1);
139+
aOp = operand.getDefiningOp();
140+
} else if (auto funcOp = dyn_cast<FunctionOpInterface>(argOwner)) {
141+
#if 1
142+
llvm::errs() << "func arg: " << funcOp.getArgument(argNum) << "\n";
143+
aOp = funcOp.getArgument(argNum).getDefiningOp();
144+
#else
145+
llvm::errs() << "funcOp num args: " << funcOp.getNumArguments() << "\n";
146+
llvm::errs() << "arg number: " << argNum << "\n";
147+
llvm::errs() << "func op at arg num: " << funcOp.getArgument(argNum) << "\n";
148+
llvm::errs() << "func op at arg num - 1: " << funcOp.getArgument(argNum -1) << "\n";
149+
llvm::errs() << "func op at arg num - 2: " << funcOp.getArgument(argNum -2) << "\n";
150+
151+
llvm::errs() << "funcOp: " << funcOp << "\n";
152+
assert(false && "funcOp!");
153+
#endif
154+
} else {
155+
llvm_unreachable("Unable to parse dpas op argument");
156+
}
157+
assert(aOp && "failed to get defining operation for DPAS A value");
158+
#if 0
159+
llvm::errs() << "arg: " << arg << "\n";
160+
// TODO
161+
aOp = arg.getDefiningOp();
162+
if (aOp) {
163+
llvm::errs() << "a op from arg: " << *aOp << "\n";
164+
} else {
165+
assert(false && "no aOp!");
166+
}
167+
#endif
133168
} else {
134-
assert(isa<tt::LoadOp>(aOp) && "expecting load input to DPAS");
169+
aOp = a.getDefiningOp();
170+
}
171+
llvm::errs() << "Broke on aOP: " << *aOp << "\n";
172+
#endif
173+
#if 1
135174
assert(aOp->getNumResults() == 1);
136175
auto ret = aOp->getResult(0);
137-
order = triton::gpu::getOrder(
138-
cast<RankedTensorType>(ret.getType()).getEncoding());
176+
#else
177+
if (isa<ttg::ConvertLayoutOp>(aOp)) {
178+
assert(aOp->getNumOperands() == 1);
179+
auto aLoad = aOp->getOperand(0);
180+
order = triton::gpu::getOrder(
181+
cast<RankedTensorType>(aLoad.getType()).getEncoding());
182+
} else {
183+
assert(isa<tt::LoadOp>(aOp) && "expecting load input to DPAS");
184+
assert(aOp->getNumResults() == 1);
185+
auto ret = aOp->getResult(0);
186+
order = triton::gpu::getOrder(
187+
cast<RankedTensorType>(ret.getType()).getEncoding());
188+
}
189+
#endif
190+
} else {
191+
llvm::errs() << "no A op for A: " << a << "\n";
139192
}
140193

141194
SmallVector<unsigned> warpsPerTile =

0 commit comments

Comments
 (0)