@@ -343,6 +343,15 @@ struct BlockIOConversionBase : public LoadStoreConversionBase {
343343 : getDotEncoding (tensorTy).value ().getParent ());
344344 }
345345
346+ static RankedTensorType getDpasTypeFromCVTOp (Value opResult) {
347+ for (OpOperand user : opResult.getUsers ()) {
348+ if (auto cvt = dyn_cast<ConvertLayoutOp>(user.getOwner ())) {
349+ return cast<RankedTensorType>(cvt.getResult ().getType ());
350+ }
351+ }
352+ llvm_unreachable (" expected to find a cvt op with dpas layout" );
353+ }
354+
346355 // Returns the pitch (stride in bytes) of \p ptr.
347356 Value getPitch (ConversionPatternRewriter &rewriter, Value ptr,
348357 const std::map<SmallVector<unsigned >, Value> &ptrs,
@@ -1418,16 +1427,6 @@ struct LoadOpConversion
14181427
14191428 const bool memoryRowMajor = isMemoryRowMajor (op);
14201429
1421- auto getDpasTypeFromCVTOp = [&](Value opResult) -> RankedTensorType {
1422- for (OpOperand user : opResult.getUsers ()) {
1423- if (auto cvt = dyn_cast<ConvertLayoutOp>(user.getOwner ())) {
1424- return cast<RankedTensorType>(cvt.getResult ().getType ());
1425- // return getDpasLayout(cvt.getResult().getType());
1426- }
1427- }
1428- llvm_unreachable (" expected to find a cvt op with dpas layout" );
1429- };
1430-
14311430 auto dpasTensorType = hasSubgroup2DBlockEncoding (tensorType)
14321431 ? getDpasTypeFromCVTOp (op.getResult ())
14331432 : tensorType;
@@ -2213,6 +2212,8 @@ struct LoadOpConversion
22132212 }
22142213
22152214 Type llvmResultStructTy = typeConverter->convertType (op.getType ());
2215+ LLVM_DEBUG (llvm::dbgs () << " Packing load result in struct "
2216+ << llvmResultStructTy << " \n " );
22162217 Value resultStruct = packLLElements (loc, typeConverter, unpackedLoadedVals,
22172218 rewriter, llvmResultStructTy);
22182219 rewriter.replaceOp (op, {resultStruct});
@@ -2235,10 +2236,16 @@ struct LoadOpConversion
22352236 Value mask = op.getMask ();
22362237 Value llMask = adaptor.getMask ();
22372238
2239+ auto opType = op.getType ();
2240+ // TODO: Override the OpType since conversion is still happening during Load
2241+ // lowering. Once we materialize ConvertLayoutOp this can be removed.
2242+ if (auto tensorTy = dyn_cast<RankedTensorType>(opType);
2243+ hasSubgroup2DBlockEncoding (tensorTy))
2244+ opType = getDpasTypeFromCVTOp (op.getResult ());
2245+
22382246 // Determine the vectorization size
2239- Type valueElemTy =
2240- typeConverter->convertType (getElementTypeOrSelf (op.getType ()));
2241- unsigned numElems = getTotalElemsPerThread (op.getType ());
2247+ Type valueElemTy = typeConverter->convertType (getElementTypeOrSelf (opType));
2248+ unsigned numElems = getTotalElemsPerThread (opType);
22422249 unsigned vec = getVectorSize (ptr);
22432250 if (llMask)
22442251 vec = std::min<size_t >(vec, getMaskAlignment (mask));
@@ -2249,7 +2256,7 @@ struct LoadOpConversion
22492256
22502257 if (isTensorPointerType (ptr.getType ())) {
22512258 // fallback to gather load.
2252- auto tensorType = cast<RankedTensorType>(op. getType () );
2259+ auto tensorType = cast<RankedTensorType>(opType );
22532260 std::tie (ptrElems, maskElems, otherElems) = convertBlockPtrToTensorOfPtr (
22542261 loc, adaptor.getPtr (), tensorType, valueElemTy, rewriter,
22552262 op.getBoundaryCheck (), op.getPadding ());
@@ -2396,7 +2403,7 @@ struct LoadOpConversion
23962403 }
23972404 } // end vec
23982405
2399- Type llvmResultStructTy = typeConverter->convertType (op. getType () );
2406+ Type llvmResultStructTy = typeConverter->convertType (opType );
24002407 Value resultStruct = packLLElements (loc, typeConverter, loadedVals,
24012408 rewriter, llvmResultStructTy);
24022409 rewriter.replaceOp (op, {resultStruct});
0 commit comments