@@ -499,7 +499,9 @@ struct LoadOpConversion
499499 auto tensorType = cast<RankedTensorType>(resultType);
500500
501501 // Only lower loadOp with dpas layout encoding.
502- if (!hasDotDpasEncoding (tensorType))
502+ auto encoding = tensorType.getEncoding ();
503+ const bool hasDpasLayout = isa<DpasEncodingAttr>(encoding);
504+ if (!hasDpasLayout && !hasDotDpasEncoding (tensorType))
503505 return failure ();
504506
505507 Attribute blockIOAttr =
@@ -514,8 +516,11 @@ struct LoadOpConversion
514516 " Only row_major or column_major is supported" );
515517 const bool memoryRowMajor = (memoryLayoutInfo == " row_major" );
516518
517- DotOperandEncodingAttr dotLayout = getDotEncoding (tensorType).value ();
518- auto dotOrder = dotLayout.getThreadOrder ();
519+ auto dpasLayout = hasDpasLayout
520+ ? cast<DpasEncodingAttr>(encoding)
521+ : cast<DpasEncodingAttr>(
522+ getDotEncoding (tensorType).value ().getParent ());
523+ auto dotOrder = dpasLayout.getThreadOrder ();
519524 size_t rank = dotOrder.size ();
520525 const bool valueRowMajor =
521526 (dotOrder[rank - 2 ] == 1 && dotOrder[rank - 1 ] == 0 );
@@ -524,10 +529,19 @@ struct LoadOpConversion
524529 " Only row_major or column_major is allowed" );
525530 const bool isTransposeRequired = valueRowMajor ^ memoryRowMajor;
526531
527- auto dpasLayout = cast<DpasEncodingAttr>(dotLayout.getParent ());
532+ auto getOpIdx = [&]() -> DpasEncodingAttr::OpIdx {
533+ if (hasDpasLayout) {
534+ return DpasEncodingAttr::OpIdx::OperandC;
535+ } else {
536+ auto dotLayout = getDotEncoding (tensorType).value ();
537+ return static_cast <DpasEncodingAttr::OpIdx>(dotLayout.getOpIdx ());
538+ }
539+ };
528540
529- auto opIdx = static_cast <DpasEncodingAttr::OpIdx>(dotLayout. getOpIdx () );
541+ auto opIdx = getOpIdx ();
530542 Type eltTy = tensorType.getElementType ();
543+ unsigned elemSizeInBits = eltTy.getIntOrFloatBitWidth ();
544+
531545 const ArrayRef<int64_t > tensorShape = tensorType.getShape ();
532546 unsigned numElems = getTotalElemsPerThread (resultType);
533547 SmallVector<int64_t > numReps =
@@ -543,6 +557,123 @@ struct LoadOpConversion
543557 SmallVector<Value> multiDimWarpId =
544558 delinearize (rewriter, loc, warpId, warpsPerCTA, dpasOrder);
545559
560+ if (hasDpasLayout) {
561+ // A block load with the DPAS layout but without the DotDpasLayout is
562+ // expected to follow the ordering of the DPAS output. For a 2D block
563+ // load, the rows are distributed across work items/SIMD lanes and the
564+ // column vectors are available for each work item to process. This layout
565+ // aligns to the DPAS layout as the DPAS operation output layout
566+ // distributes rows across work items.
567+ if (isTransposeRequired) {
568+ // TODO: this would likely require a shuffle to match the expected
569+ // ordering coming out of the DPAS layout and requires more
570+ // investigation
571+ return failure ();
572+ }
573+
574+ MLIRContext *ctx = rewriter.getContext ();
575+
576+ Value elemSizeInBytes = i32_val (elemSizeInBits / 8 );
577+
578+ SmallVector<unsigned > elemsPerInstr = dpasLayout.getDPASInstShapeC ();
579+ int64_t elemsPerLane = product<unsigned >(elemsPerInstr) / threadsPerWarp;
580+ Type load2DGenXType =
581+ LLVM::getFixedVectorType (IntegerType::get (ctx, elemSizeInBits),
582+ elemsPerLane); // make it opaque type.
583+
584+ auto [base, baseWidth, baseHeight, rowStride, colStride, offsetBaseX,
585+ offsetBaseY] =
586+ getValuesFromBlockPointerStruct (adaptor.getPtr (), rewriter);
587+ baseWidth = trunc (i32_ty, baseWidth);
588+ baseHeight = trunc (i32_ty, baseHeight);
589+
590+ auto pitch = trunc (i32_ty, rowStride);
591+
592+ SmallVector<unsigned > repClusterShape = dpasLayout.getShapeC ();
593+ unsigned outerDimWarpNum =
594+ std::min<unsigned >(warpsPerCTA[rank - 2 ],
595+ mlir::ceil<unsigned >(tensorShape[rank - 2 ],
596+ repClusterShape[rank - 2 ]));
597+ unsigned innerDimWarpNum =
598+ std::min<unsigned >(warpsPerCTA[rank - 1 ],
599+ mlir::ceil<unsigned >(tensorShape[rank - 1 ],
600+ repClusterShape[rank - 1 ]));
601+ Value outerDimWarpId =
602+ urem (multiDimWarpId[rank - 2 ], i32_val (outerDimWarpNum));
603+ Value innerDimWarpId =
604+ urem (multiDimWarpId[rank - 1 ], i32_val (innerDimWarpNum));
605+ int64_t numRepOuter = numReps[1 ];
606+ int64_t numRepInner = numReps[2 ];
607+
608+ std::array<unsigned , 2 > replicaStride = {
609+ outerDimWarpNum * repClusterShape[rank - 2 ],
610+ innerDimWarpNum * repClusterShape[rank - 1 ]};
611+ std::array<unsigned , 2 > warpStride = {repClusterShape[rank - 2 ],
612+ repClusterShape[rank - 1 ]};
613+
614+ Value dimWarpId0 = mul (outerDimWarpId, i32_val (warpStride[0 ]));
615+ Value dimWarpId1 = mul (innerDimWarpId, i32_val (warpStride[1 ]));
616+ Value warpId0Offset = add (dimWarpId0, offsetBaseY);
617+ Value warpId1Offset = add (dimWarpId1, offsetBaseX);
618+
619+ ArrayRef<unsigned > repCluster = dpasLayout.getRepCluster ();
620+ unsigned valOffset = 0 ;
621+
622+ SmallVector<Value> unpackedLoadedVals;
623+
624+ for (int m = 0 ; m < numRepOuter; ++m) {
625+ for (int n = 0 ; n < numRepInner; ++n) {
626+ for (int repM = 0 ; repM < repCluster[0 ]; ++repM) {
627+
628+ Value offsetY =
629+ add (warpId0Offset,
630+ i32_val (m * replicaStride[0 ] + repM * elemsPerInstr[0 ]));
631+ for (int repN = 0 ; repN < repCluster[1 ]; ++repN) {
632+ Value offsetX =
633+ add (warpId1Offset,
634+ i32_val (n * replicaStride[1 ] + repN * elemsPerInstr[1 ]));
635+
636+ auto load2dOp = rewriter.create <TritonGEN::Matrix2DBlockLoadOp>(
637+ loc, load2DGenXType,
638+ /* ptr*/ base,
639+ /* base_width*/ mul (baseWidth, elemSizeInBytes),
640+ /* base_height*/ baseHeight,
641+ /* base_pitch*/ mul (pitch, elemSizeInBytes),
642+ /* x*/ trunc (i32_ty, offsetX),
643+ /* y*/ trunc (i32_ty, offsetY),
644+ /* elem_size_in_bits*/ elemSizeInBits,
645+ /* tile_width*/ elemsPerInstr[1 ],
646+ /* tile_height*/ elemsPerInstr[0 ],
647+ /* v_blocks*/ 1 ,
648+ /* transpose*/ false ,
649+ /* vnni_transform*/ false );
650+ if (failed (load2dOp.verify ())) {
651+ // Explicitly invoke verifier because `triton_gen` ops are
652+ // immediately lowered further to a builtin call.
653+ return failure ();
654+ }
655+
656+ Value ret = bitcast (
657+ load2dOp, LLVM::getFixedVectorType (eltTy, elemsPerLane));
658+
659+ for (size_t i = 0 ; i < elemsPerLane; i++) {
660+ Value loaded = extract_element (eltTy, ret, i32_val (i));
661+ unpackedLoadedVals.push_back (loaded);
662+ }
663+ }
664+ }
665+ }
666+ }
667+
668+ TritonGPUToLLVMTypeConverter *typeConverter = getTypeConverter ();
669+ Type llvmResultStructTy = typeConverter->convertType (op.getType ());
670+ Value resultStruct = packLLElements (
671+ loc, typeConverter, unpackedLoadedVals, rewriter, llvmResultStructTy);
672+ rewriter.replaceOp (op, {resultStruct});
673+
674+ return success ();
675+ }
676+
546677 bool isOperandA = (opIdx == DpasEncodingAttr::OpIdx::OperandA);
547678 SmallVector<unsigned > dpasInstShape = isOperandA
548679 ? dpasLayout.getDPASInstShapeA ()
@@ -573,11 +704,11 @@ struct LoadOpConversion
573704 // input operands to DPAS.
574705 // TODO: add support for int4 and int2.
575706 unsigned opsPerChannel = dpasLayout.getOpsPerChannel ();
576- unsigned elemBits = eltTy. getIntOrFloatBitWidth ();
577- if (( opsPerChannel == 4 && elemBits == 8 ) ||
578- (opsPerChannel == 2 && elemBits == 16 ) ||
579- (opsPerChannel == 1 && elemBits == 32 )) {
580- loadResultElemType = (isOperandA && elemBits != 32 ) ? i16_ty : i32_ty;
707+ if ((opsPerChannel == 4 && elemSizeInBits == 8 ) ||
708+ ( opsPerChannel == 2 && elemSizeInBits == 16 ) ||
709+ (opsPerChannel == 1 && elemSizeInBits == 32 )) {
710+ loadResultElemType =
711+ (isOperandA && elemSizeInBits != 32 ) ? i16_ty : i32_ty;
581712 packedElemsPerLanePerDPASInst =
582713 isOperandA ? elemsPerLanePerDPASInst / (opsPerChannel == 4 ? 2 : 1 )
583714 : elemsPerLanePerDPASInst / opsPerChannel;
@@ -651,7 +782,7 @@ struct LoadOpConversion
651782
652783 // PVC 2D load supports 64 bytes per row at most. Load multiple dot operands
653784 // by enlarging the vBlocks.
654- unsigned totalBytesPerRowPerDPASOp = tileWidth * elemBits / 8 ;
785+ unsigned totalBytesPerRowPerDPASOp = tileWidth * elemSizeInBits / 8 ;
655786 numOperandsPer2DloadN =
656787 std::min (numOperandsPer2DloadN, 64 / totalBytesPerRowPerDPASOp);
657788 vBlocks = numOperandsPer2DloadN;
@@ -695,12 +826,12 @@ struct LoadOpConversion
695826 baseWidth = trunc (i32_ty, baseWidth);
696827 baseHeight = trunc (i32_ty, baseHeight);
697828
698- unsigned originalElemBits = elemBits ;
829+ const unsigned originalElemBits = elemSizeInBits ;
699830 if (isTransposeRequired) {
700831 // adjust the block io parameter to align HW's limitations on
701832 // transposing load.
702833 tileWidth = tileWidth / (32 / originalElemBits);
703- elemBits = 32 ;
834+ elemSizeInBits = 32 ;
704835 }
705836 Value elemSizeInBytes = i32_val (originalElemBits / 8 );
706837
@@ -747,14 +878,14 @@ struct LoadOpConversion
747878 /* base_pitch*/ mul (pitch, elemSizeInBytes),
748879 /* x*/ trunc (i32_ty, offsetX),
749880 /* y*/ trunc (i32_ty, offsetY),
750- /* elem_size_in_bits*/ elemBits ,
881+ /* elem_size_in_bits*/ elemSizeInBits ,
751882 /* tile_width*/ tileWidth,
752883 /* tile_height*/ tileHeight,
753884 /* v_blocks*/ vBlocks,
754885 /* transpose*/ isTransposeRequired,
755886 /* vnni_transform*/
756887 (usePackedType && !isOperandA && !isTransposeRequired &&
757- eltTy. getIntOrFloatBitWidth () != 32 ));
888+ originalElemBits != 32 ));
758889 if (failed (load2dOp.verify ())) {
759890 // Explicitly invoke verifier because `triton_gen` ops are
760891 // immediately lowered further to a builtin call.
0 commit comments