@@ -123,19 +123,72 @@ class BlockedToDPAS : public OpRewritePattern<tt::DotOp> {
123123 oldAType.getElementType ().isFloat8E4M3FN ())
124124 dpasElemBitWidths = 2 * dpasElemBitWidths;
125125
126- SmallVector<unsigned > order;
127- Operation *aOp = a.getDefiningOp ();
128- if (isa<ttg::ConvertLayoutOp>(aOp)) {
129- assert (aOp->getNumOperands () == 1 );
130- auto aLoad = aOp->getOperand (0 );
131- order = triton::gpu::getOrder (
132- cast<RankedTensorType>(aLoad.getType ()).getEncoding ());
126+ SmallVector<unsigned > order = {1 , 0 }; // TODO: acceptable default arg?
127+ llvm::errs () << " a: " << a << " \n " ;
128+ Operation* aOp = a.getDefiningOp ();
129+ if (aOp) {
130+ llvm::errs () << " Processing a op: " << *aOp << " \n " ;
131+ #if 0
132+ Operation *aOp{nullptr};
133+ if (auto arg = dyn_cast<BlockArgument>(a)) {
134+ unsigned argNum = arg.getArgNumber();
135+ Operation *argOwner = a.getParentBlock()->getParentOp();
136+
137+ if (auto forOp = dyn_cast<scf::ForOp>(argOwner)) {
138+ auto operand = forOp.getOperand(argNum + forOp.getNumControlOperands() - 1);
139+ aOp = operand.getDefiningOp();
140+ } else if (auto funcOp = dyn_cast<FunctionOpInterface>(argOwner)) {
141+ #if 1
142+ llvm::errs() << "func arg: " << funcOp.getArgument(argNum) << "\n";
143+ aOp = funcOp.getArgument(argNum).getDefiningOp();
144+ #else
145+ llvm::errs() << "funcOp num args: " << funcOp.getNumArguments() << "\n";
146+ llvm::errs() << "arg number: " << argNum << "\n";
147+ llvm::errs() << "func op at arg num: " << funcOp.getArgument(argNum) << "\n";
148+ llvm::errs() << "func op at arg num - 1: " << funcOp.getArgument(argNum -1) << "\n";
149+ llvm::errs() << "func op at arg num - 2: " << funcOp.getArgument(argNum -2) << "\n";
150+
151+ llvm::errs() << "funcOp: " << funcOp << "\n";
152+ assert(false && "funcOp!");
153+ #endif
154+ } else {
155+ llvm_unreachable("Unable to parse dpas op argument");
156+ }
157+ assert(aOp && "failed to get defining operation for DPAS A value");
158+ #if 0
159+ llvm::errs() << "arg: " << arg << "\n";
160+ // TODO
161+ aOp = arg.getDefiningOp();
162+ if (aOp) {
163+ llvm::errs() << "a op from arg: " << *aOp << "\n";
164+ } else {
165+ assert(false && "no aOp!");
166+ }
167+ #endif
133168 } else {
134- assert (isa<tt::LoadOp>(aOp) && " expecting load input to DPAS" );
169+ aOp = a.getDefiningOp();
170+ }
171+ llvm::errs() << "Broke on aOP: " << *aOp << "\n";
172+ #endif
173+ #if 1
135174 assert (aOp->getNumResults () == 1 );
136175 auto ret = aOp->getResult (0 );
137- order = triton::gpu::getOrder (
138- cast<RankedTensorType>(ret.getType ()).getEncoding ());
176+ #else
177+ if (isa<ttg::ConvertLayoutOp>(aOp)) {
178+ assert(aOp->getNumOperands() == 1);
179+ auto aLoad = aOp->getOperand(0);
180+ order = triton::gpu::getOrder(
181+ cast<RankedTensorType>(aLoad.getType()).getEncoding());
182+ } else {
183+ assert(isa<tt::LoadOp>(aOp) && "expecting load input to DPAS");
184+ assert(aOp->getNumResults() == 1);
185+ auto ret = aOp->getResult(0);
186+ order = triton::gpu::getOrder(
187+ cast<RankedTensorType>(ret.getType()).getEncoding());
188+ }
189+ #endif
190+ } else {
191+ llvm::errs () << " no A op for A: " << a << " \n " ;
139192 }
140193
141194 SmallVector<unsigned > warpsPerTile =
0 commit comments