@@ -302,7 +302,8 @@ struct BlockIOConversionBase : public LoadStoreConversionBase {
302302
303303 // Only lower loadOp with dpas layout encoding.
304304 auto tensorTy = cast<RankedTensorType>(op.getType ());
305- return hasDpasEncoding (tensorTy) || hasDotDpasEncoding (tensorTy);
305+ return hasDpasEncoding (tensorTy) || hasDotDpasEncoding (tensorTy) ||
306+ hasSubgroup2DBlockEncoding (tensorTy);
306307 }
307308
308309 template <
@@ -1416,12 +1417,31 @@ struct LoadOpConversion
14161417 auto tensorType = cast<RankedTensorType>(resultType);
14171418
14181419 const bool memoryRowMajor = isMemoryRowMajor (op);
1419- DpasEncodingAttr::OpIdx opIdx = getOpIdx (tensorType);
1420+
1421+ auto getDpasTypeFromCVTOp = [&](Value opResult) -> RankedTensorType {
1422+ for (OpOperand user : opResult.getUsers ()) {
1423+ if (auto cvt = dyn_cast<ConvertLayoutOp>(user.getOwner ())) {
1424+ return cast<RankedTensorType>(cvt.getResult ().getType ());
1425+ // return getDpasLayout(cvt.getResult().getType());
1426+ }
1427+ }
1428+ llvm_unreachable (" expected to find a cvt op with dpas layout" );
1429+ };
1430+
1431+ auto dpasTensorType = hasSubgroup2DBlockEncoding (tensorType)
1432+ ? getDpasTypeFromCVTOp (op.getResult ())
1433+ : tensorType;
1434+ llvm::errs () << " using dpas tensor type: " << dpasTensorType << " \n " ;
1435+ DpasEncodingAttr dpasLayout = getDpasLayout (dpasTensorType);
1436+
1437+ DpasEncodingAttr::OpIdx opIdx = getOpIdx (dpasTensorType);
14201438
14211439 LLVM_DEBUG (llvm::dbgs () << " Tensor type for op " << int (opIdx) << " : "
14221440 << tensorType << " \n " );
14231441
14241442 Attribute encoding = tensorType.getEncoding ();
1443+ // TODO: this gives us the linear layout corresponding
1444+ // to the subgroup 2d block encoding, not the dpas encoding...
14251445 std::optional<LinearLayout> llEncoding =
14261446 cast<DistributedEncodingTrait>(encoding).toLinearLayout (
14271447 tensorType.getShape ());
@@ -1440,14 +1460,21 @@ struct LoadOpConversion
14401460 Type eltTy = tensorType.getElementType ();
14411461 unsigned elemSizeInBits = eltTy.getIntOrFloatBitWidth ();
14421462
1443- auto tileParams = Subgroup2DBlockEncodingAttr::getInstrShapeForLayout (
1444- cast<DistributedEncodingTrait>(encoding), tensorType.getShape (),
1445- memoryRowMajor, elemSizeInBits / 8 , rewriter.getContext ());
1446- unsigned tileHeight = tileParams[0 ];
1447- const unsigned tileWidth = tileParams[1 ];
1448- const unsigned vBlocks = tileParams[2 ];
1463+ auto getTileParams = [&]() -> std::tuple<unsigned , unsigned , unsigned > {
1464+ if (hasSubgroup2DBlockEncoding (tensorType)) {
1465+ auto encoding =
1466+ cast<Subgroup2DBlockEncodingAttr>(tensorType.getEncoding ());
1467+ auto shape = encoding.getInstrShape ();
1468+ return std::make_tuple (shape[0 ], shape[1 ], encoding.getNumBlocks ());
1469+ } else {
1470+ auto tileParams = Subgroup2DBlockEncodingAttr::getInstrShapeForLayout (
1471+ cast<DistributedEncodingTrait>(encoding), tensorType.getShape (),
1472+ memoryRowMajor, elemSizeInBits / 8 , rewriter.getContext ());
1473+ return std::make_tuple (tileParams[0 ], tileParams[1 ], tileParams[2 ]);
1474+ }
1475+ };
1476+ auto [tileHeight, tileWidth, vBlocks] = getTileParams ();
14491477
1450- DpasEncodingAttr dpasLayout = getDpasLayout (tensorType);
14511478 const ArrayRef<int64_t > tensorShape = tensorType.getShape ();
14521479 unsigned numElems = getTotalElemsPerThread (resultType);
14531480 SmallVector<int64_t > numReps =
@@ -1617,6 +1644,7 @@ struct LoadOpConversion
16171644 // input operands to DPAS.
16181645 // TODO: add support for int4 and int2.
16191646 unsigned opsPerChannel = dpasLayout.getOpsPerChannel ();
1647+ llvm::errs () << " opsPerChannel = " << opsPerChannel << " \n " ;
16201648 if ((opsPerChannel == 4 && elemSizeInBits == 8 ) ||
16211649 (opsPerChannel == 2 && elemSizeInBits == 16 ) ||
16221650 (opsPerChannel == 1 && elemSizeInBits == 32 )) {
@@ -1840,6 +1868,8 @@ struct LoadOpConversion
18401868 unsigned numValuesPerLoad = packedElemsPerLanePerDPASInst *
18411869 numOperandsOuterDimPerLoad *
18421870 numOperandsInnerDimPerLoad;
1871+ llvm::errs () << " num values per load = " << numValuesPerLoad << " \n " ;
1872+ llvm::errs () << " loadResultElemType = " << loadResultElemType << " \n " ;
18431873 Type load2DGenXType =
18441874 LLVM::getVectorType (loadResultElemType, numValuesPerLoad);
18451875
@@ -2187,6 +2217,8 @@ struct LoadOpConversion
21872217 }
21882218
21892219 Type llvmResultStructTy = typeConverter->convertType (op.getType ());
2220+ llvm::errs () << " op.getType() " << op.getType () << " \n " ;
2221+ llvm::errs () << " llvmResultStructTy: " << llvmResultStructTy << " \n " ;
21902222 Value resultStruct = packLLElements (loc, typeConverter, unpackedLoadedVals,
21912223 rewriter, llvmResultStructTy);
21922224 rewriter.replaceOp (op, {resultStruct});
0 commit comments