@@ -2521,18 +2521,6 @@ struct LoadOpToBlockIOConversion
25212521 if (tileHeight * tileWidth * packedElemSizeInBits / 8 < GRF_SIZE)
25222522 vBlocks = 1 ;
25232523
2524- // TODO: use the axis info to general the handling for both regular pointer
2525- // and block pointer.
2526- const bool memoryRowMajor = isMemoryRowMajor (op);
2527- // FIXME: Add support of column major.
2528- if (!memoryRowMajor)
2529- return failure ();
2530-
2531- unsigned contiguousDim = memoryRowMajor ? 1 : 0 ;
2532- const bool isTransposeRequired = contiguousDim != colDim;
2533- if (isTransposeRequired)
2534- return matchAndRewriteTranspose (op, adaptor, rewriter);
2535-
25362524 Location loc = op.getLoc ();
25372525 auto b = TritonLLVMOpBuilder (loc, rewriter);
25382526 MLIRContext *ctx = rewriter.getContext ();
@@ -2661,6 +2649,55 @@ struct LoadOpToBlockIOConversion
26612649 }
26622650 }
26632651
2652+ // TODO: use the axis info to generalize the handling for both regular
2653+ // pointers and block pointers.
2654+ const bool memoryRowMajor = isMemoryRowMajor (op);
2655+ unsigned contiguousDim = memoryRowMajor ? 1 : 0 ;
2656+ const bool isTransposeRequired = contiguousDim != colDim;
2657+
2658+ if (isTransposeRequired) {
2659+ if (numPackedVals > 1 )
2660+ return failure ();
2661+ if (elemSizeInBits > 32 )
2662+ return failure ();
2663+ if (tileWidth > 32 )
2664+ return failure (); // tileWidth is limited to 32 for transpose 2d load.
2665+
2666+ vBlocks = 1 ;
2667+
2668+ // Use the d32 element size for the transpose 2d load.
2669+ packedElemSizeInBits = 32 ;
2670+ numPackedVals = packedElemSizeInBits / elemSizeInBits;
2671+ if (numPackedVals > 1 && tileWidth != threadsPerWarp)
2672+ return failure (); // Couldn't use the transpose 2d load for un-packable
2673+ // along tile height dim.
2674+ tileHeight = std::min (tileHeight / numPackedVals, 8 );
2675+
2676+ if (tileHeight * tileWidth < threadsPerWarp)
2677+ return failure (); // The tile size is not large enough for IGC scalar
2678+ // backend vectorization.
2679+ // transpose the width and height of the tile
2680+ std::swap (tileHeight, tileWidth);
2681+ // if (oneMatrixPerLoadForBT) {
2682+ // // Only load 1 operand per inst on row.
2683+ // numOperandsPer2DLoadM = 1;
2684+ // tileHeight = elemsPerDPASInst[threadOrder[rank - 2]];
2685+ // } else {
2686+ // // We can decompose the matrix returned by transposed large 2d load
2687+ // // when threads per warp < column size. Otherwise we have to load one
2688+ // // operand per inst.
2689+ // // Note: the tileHeight and numOperandsPer2DLoadM are the column size
2690+ // // now.
2691+ // numOperandsPer2DLoadM =
2692+ // (threadsPerWarp <= tileHeight) ? repCluster[rank - 1] : 1;
2693+ // }
2694+ // // The transpose 2d load only support 1 operand per inst on column.
2695+ // // (vBlocks = 1)
2696+ // numOperandsPer2DloadN = 1;
2697+ // // TODO: support load column major data.
2698+ // return failure();
2699+ }
2700+
26642701 int64_t numElemsPerLoad = mlir::ceil (
26652702 tileHeight * tileWidth * numPackedVals * vBlocks, (int )threadsPerWarp);
26662703 unsigned numValuesPerLoad = mlir::ceil ((int )numElemsPerLoad, numPackedVals);
@@ -2740,8 +2777,6 @@ struct LoadOpToBlockIOConversion
27402777 }
27412778 } break ;
27422779 case DpasEncodingAttr::OpIdx::OperandB: {
2743- assert (numPackedVals == 1 &&
2744- " invalid number of packed values for DPAS operand B." );
27452780 unsigned elemsPerLanePerDPASInst =
27462781 product<unsigned >(dpasLayout.getDPASInstShapeB ()) / threadsPerWarp;
27472782 // Block 2D contain at least one DotOp B.
@@ -2751,6 +2786,9 @@ struct LoadOpToBlockIOConversion
27512786 if (tileHeight >= (opsPerChannel * sysDepth) &&
27522787 ((opsPerChannel == 4 && elemSizeInBits == 8 ) ||
27532788 (opsPerChannel == 2 && elemSizeInBits == 16 ))) {
2789+ assert (!isTransposeRequired ||
2790+ opsPerChannel == numPackedVals &&
2791+ " invalid opsPerChannel for transposed DotOp B" );
27542792 // Use the VNNI packing format for DotOp B layout.
27552793 numValuesPerLoad = numElemsPerLoad / opsPerChannel;
27562794 packedType = i32_ty;
@@ -2814,8 +2852,8 @@ struct LoadOpToBlockIOConversion
28142852 /* tile_width*/ tileWidth,
28152853 /* tile_height*/ tileHeight,
28162854 /* v_blocks*/ vBlocks,
2817- /* transpose*/ false ,
2818- /* vnni_transform*/ useVNNIFormat);
2855+ /* transpose*/ isTransposeRequired ,
2856+ /* vnni_transform*/ !isTransposeRequired && useVNNIFormat);
28192857
28202858 // When strides[0] is 0, we only want to load the first row, so we
28212859 // set the base height to be 1. If tile height is bigger than 1,
0 commit comments