@@ -123,70 +123,13 @@ class BlockedToDPAS : public OpRewritePattern<tt::DotOp> {
123123 oldAType.getElementType ().isFloat8E4M3FN ())
124124 dpasElemBitWidths = 2 * dpasElemBitWidths;
125125
126- SmallVector<unsigned > order = {1 , 0 }; // TODO: acceptable default arg?
126+ SmallVector<unsigned > order = {1 , 0 }; // TODO: acceptable default arg?
127127 llvm::errs () << " a: " << a << " \n " ;
128- Operation* aOp = a.getDefiningOp ();
128+ Operation * aOp = a.getDefiningOp ();
129129 if (aOp) {
130130 llvm::errs () << " Processing a op: " << *aOp << " \n " ;
131- #if 0
132- Operation *aOp{nullptr};
133- if (auto arg = dyn_cast<BlockArgument>(a)) {
134- unsigned argNum = arg.getArgNumber();
135- Operation *argOwner = a.getParentBlock()->getParentOp();
136-
137- if (auto forOp = dyn_cast<scf::ForOp>(argOwner)) {
138- auto operand = forOp.getOperand(argNum + forOp.getNumControlOperands() - 1);
139- aOp = operand.getDefiningOp();
140- } else if (auto funcOp = dyn_cast<FunctionOpInterface>(argOwner)) {
141- #if 1
142- llvm::errs() << "func arg: " << funcOp.getArgument(argNum) << "\n";
143- aOp = funcOp.getArgument(argNum).getDefiningOp();
144- #else
145- llvm::errs() << "funcOp num args: " << funcOp.getNumArguments() << "\n";
146- llvm::errs() << "arg number: " << argNum << "\n";
147- llvm::errs() << "func op at arg num: " << funcOp.getArgument(argNum) << "\n";
148- llvm::errs() << "func op at arg num - 1: " << funcOp.getArgument(argNum -1) << "\n";
149- llvm::errs() << "func op at arg num - 2: " << funcOp.getArgument(argNum -2) << "\n";
150-
151- llvm::errs() << "funcOp: " << funcOp << "\n";
152- assert(false && "funcOp!");
153- #endif
154- } else {
155- llvm_unreachable("Unable to parse dpas op argument");
156- }
157- assert(aOp && "failed to get defining operation for DPAS A value");
158- #if 0
159- llvm::errs() << "arg: " << arg << "\n";
160- // TODO
161- aOp = arg.getDefiningOp();
162- if (aOp) {
163- llvm::errs() << "a op from arg: " << *aOp << "\n";
164- } else {
165- assert(false && "no aOp!");
166- }
167- #endif
168- } else {
169- aOp = a.getDefiningOp();
170- }
171- llvm::errs() << "Broke on aOP: " << *aOp << "\n";
172- #endif
173- #if 1
174131 assert (aOp->getNumResults () == 1 );
175132 auto ret = aOp->getResult (0 );
176- #else
177- if (isa<ttg::ConvertLayoutOp>(aOp)) {
178- assert(aOp->getNumOperands() == 1);
179- auto aLoad = aOp->getOperand(0);
180- order = triton::gpu::getOrder(
181- cast<RankedTensorType>(aLoad.getType()).getEncoding());
182- } else {
183- assert(isa<tt::LoadOp>(aOp) && "expecting load input to DPAS");
184- assert(aOp->getNumResults() == 1);
185- auto ret = aOp->getResult(0);
186- order = triton::gpu::getOrder(
187- cast<RankedTensorType>(ret.getType()).getEncoding());
188- }
189- #endif
190133 } else {
191134 llvm::errs () << " no A op for A: " << a << " \n " ;
192135 }
0 commit comments