@@ -7,10 +7,8 @@ using namespace mlir;
77using ValueTable = std::map<std::array<int , 3 >, Value>;
88using ::mlir::LLVM::delinearize;
99using ::mlir::LLVM::getSharedMemoryObjectFromStruct;
10- using ::mlir::LLVM::getStridesFromShapeAndOrder;
1110using ::mlir::triton::gpu::DotOperandEncodingAttr;
1211using ::mlir::triton::gpu::getContigPerThread;
13- using ::mlir::triton::gpu::getOrder;
1412using ::mlir::triton::gpu::getShapePerCTA;
1513using ::mlir::triton::gpu::getSizePerThread;
1614using ::mlir::triton::gpu::getTotalElemsPerThread;
@@ -608,12 +606,11 @@ getLoadMatrixFn(MemDescType descTy, const SharedMemoryObject &smemObj,
608606 std::max<int >(shapePerCTA[2 ] / mmaLayout.getWarpsPerCTA ()[2 ], 8 );
609607 // (a, b) is the coordinate.
610608 auto load = [=, &rewriter, &vals](int batch, int a, int b) {
611- MMA16816SmemLoader loader (nPerWarp, warpsPerTile, sharedLayout.getOrder (),
612- mmaLayout.getWarpsPerCTA (), kOrder , kWidth ,
613- smemObj.strides , shapePerCTA /* tileShape*/ ,
614- instrShape, matShape, multiDimWarpId, perPhase,
615- maxPhase, elemBytes, mmaElemBytes, isHopper,
616- rewriter, typeConverter, loc);
609+ MMA16816SmemLoader loader (
610+ nPerWarp, warpsPerTile, order, mmaLayout.getWarpsPerCTA (), kOrder ,
611+ kWidth , smemObj.strides , shapePerCTA /* tileShape*/ , instrShape,
612+ matShape, multiDimWarpId, perPhase, maxPhase, elemBytes, mmaElemBytes,
613+ isHopper, rewriter, typeConverter, loc);
617614 // Offset of a slice within the original tensor in shared memory
618615 Value cSwizzleOffset = smemObj.getCSwizzleOffset (order[0 ]);
619616 SmallVector<Value> offs = loader.computeOffsets (lane, cSwizzleOffset);
0 commit comments