Skip to content

Commit 5b829bc

Browse files
committed
Use block loads for post-dpas vector computation
1 parent 9ec46fe commit 5b829bc

File tree

2 files changed

+177
-15
lines changed

2 files changed

+177
-15
lines changed

test/TritonIntelGPU/blockptr_load.mlir

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,37 @@ module attributes {"ttg.num-warps" = 8 : i32, "ttg.threads-per-warp" = 16 : i32}
2828

2929
// -----
3030

31+
// CHECK-DAG: llvm.func spir_funccc @_Z38intel_sub_group_f16_f16_matrix_mad_k16Dv8_sDv8_iDv8_f(vector<8xi16>, vector<8xi32>, vector<8xf32>) -> vector<8xf32> attributes {convergent, memory_effects = #llvm.memory_effects<other = none, argMem = none, inaccessibleMem = none>, no_unwind, will_return}
32+
// CHECK-DAG: llvm.func spir_funccc @_Z41intel_sub_group_2d_block_read_16b_8r16x2cPU3AS1viiiDv2_iPt(!llvm.ptr<1> {llvm.nonnull, llvm.readonly}, i32, i32, i32, vector<2xi32>, !llvm.ptr {llvm.nonnull, llvm.writeonly}) attributes {no_unwind, will_return}
33+
// CHECK-DAG: llvm.func spir_funccc @_Z52intel_sub_group_2d_block_read_transform_16b_32r16x1cPU3AS1viiiDv2_iPj(!llvm.ptr<1> {llvm.nonnull, llvm.readonly}, i32, i32, i32, vector<2xi32>, !llvm.ptr {llvm.nonnull, llvm.writeonly}) attributes {no_unwind, will_return}
34+
#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 16], warpsPerCTA = [2, 4], order = [1, 0]}>
35+
#dpas = #triton_intel_gpu.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [4, 2], repCluster = [1, 1], A = [8, 16], B = [16, 16], C = [8, 16]}>
36+
#dot0 = #ttg.dot_op<{opIdx = 0, parent = #dpas, kWidth=2}>
37+
#dot1 = #ttg.dot_op<{opIdx = 1, parent = #dpas, kWidth=2}>
38+
module attributes {"ttg.num-warps" = 8 : i32, "ttg.threads-per-warp" = 16 : i32} {
39+
tt.func public @matmul_no_scf_with_add_kernel(%arg0: !tt.ptr<f16>, %arg1: !tt.ptr<f16>, %arg2: !tt.ptr<f32>, %arg3: i64, %arg4: i64, %arg5: i64, %arg6: i64, %arg8: i64) {
40+
%C = arith.constant dense<0.000000e+00> : tensor<64x64xf32, #dpas>
41+
%c0_i32 = arith.constant 0 : i32
42+
%c1_i64 = arith.constant 1 : i64
43+
%ptrA = tt.make_tensor_ptr %arg0, [%arg3, %arg5], [%arg6, %c1_i64], [%c0_i32, %c0_i32] {order = array<i32: 1, 0>} : <tensor<64x32xf16, #dot0>>
44+
%ptrB = tt.make_tensor_ptr %arg1, [%arg5, %arg4], [%arg8, %c1_i64], [%c0_i32, %c0_i32] {order = array<i32: 1, 0>} : <tensor<32x64xf16, #dot1>>
45+
// CHECK-COUNT-2: llvm.call spir_funccc @_Z41intel_sub_group_2d_block_read_16b_8r16x2cPU3AS1viiiDv2_iPt({{.*}}) {{.*}} : (!llvm.ptr<1>, i32, i32, i32, vector<2xi32>, !llvm.ptr) -> ()
46+
// CHECK-COUNT-2: llvm.call spir_funccc @_Z52intel_sub_group_2d_block_read_transform_16b_32r16x1cPU3AS1viiiDv2_iPj({{.*}}) {{.*}} : (!llvm.ptr<1>, i32, i32, i32, vector<2xi32>, !llvm.ptr) -> ()
47+
// CHECK-COUNT-8: llvm.call spir_funccc @_Z38intel_sub_group_f16_f16_matrix_mad_k16Dv8_sDv8_iDv8_f({{.*}}) {{.*}} : (vector<8xi16>, vector<8xi32>, vector<8xf32>) -> vector<8xf32>
48+
%A = tt.load %ptrA {boundaryCheck = array<i32: 1>, padding = 1 : i32, triton_intel_gpu.block_io = "row_major"} : !tt.ptr<tensor<64x32xf16, #dot0>>
49+
%B = tt.load %ptrB {boundaryCheck = array<i32: 0>, padding = 1 : i32, triton_intel_gpu.block_io = "row_major"} : !tt.ptr<tensor<32x64xf16, #dot1>>
50+
%D = tt.dot %A, %B, %C, inputPrecision = tf32 : tensor<64x32xf16, #dot0> * tensor<32x64xf16, #dot1> -> tensor<64x64xf32, #dpas>
51+
%ptrX = tt.make_tensor_ptr %arg2, [%arg3, %arg4], [%arg8, %c1_i64], [%c0_i32, %c0_i32] {order = array<i32: 1, 0>} : <tensor<64x64xf32, #dpas>>
52+
// CHECK-COUNT-4: llvm.call spir_funccc @_Z41intel_sub_group_2d_block_read_32b_8r16x1cPU3AS1viiiDv2_iPj({{.*}}) {{.*}} : (!llvm.ptr<1>, i32, i32, i32, vector<2xi32>, !llvm.ptr) -> ()
53+
%X = tt.load %ptrX {boundaryCheck = array<i32: 0, 1>, triton_intel_gpu.block_io = "row_major"} : !tt.ptr<tensor<64x64xf32, #dpas>>
54+
// CHECK-COUNT-32: llvm.fadd {{.*}}, {{.*}}
55+
%0 = arith.addf %D, %X : tensor<64x64xf32, #dpas>
56+
tt.return
57+
}
58+
}
59+
60+
// -----
61+
3162
#dpas = #triton_intel_gpu.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 1, threadsPerWarp = 16, warpsPerCTA = [8, 4], repCluster = [1, 1], A = [8, 8], B = [8, 16], C = [8, 16]}>
3263
#dot0 = #ttg.dot_op<{opIdx = 0, parent = #dpas, kWidth=1}>
3364
#dot1 = #ttg.dot_op<{opIdx = 1, parent = #dpas, kWidth=1}>

third_party/intel/lib/TritonIntelGPUToLLVM/LoadStoreOpToLLVM.cpp

Lines changed: 146 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -499,7 +499,9 @@ struct LoadOpConversion
499499
auto tensorType = cast<RankedTensorType>(resultType);
500500

501501
// Only lower loadOp with dpas layout encoding.
502-
if (!hasDotDpasEncoding(tensorType))
502+
auto encoding = tensorType.getEncoding();
503+
const bool hasDpasLayout = isa<DpasEncodingAttr>(encoding);
504+
if (!hasDpasLayout && !hasDotDpasEncoding(tensorType))
503505
return failure();
504506

505507
Attribute blockIOAttr =
@@ -514,8 +516,11 @@ struct LoadOpConversion
514516
"Only row_major or column_major is supported");
515517
const bool memoryRowMajor = (memoryLayoutInfo == "row_major");
516518

517-
DotOperandEncodingAttr dotLayout = getDotEncoding(tensorType).value();
518-
auto dotOrder = dotLayout.getThreadOrder();
519+
auto dpasLayout = hasDpasLayout
520+
? cast<DpasEncodingAttr>(encoding)
521+
: cast<DpasEncodingAttr>(
522+
getDotEncoding(tensorType).value().getParent());
523+
auto dotOrder = dpasLayout.getThreadOrder();
519524
size_t rank = dotOrder.size();
520525
const bool valueRowMajor =
521526
(dotOrder[rank - 2] == 1 && dotOrder[rank - 1] == 0);
@@ -524,10 +529,19 @@ struct LoadOpConversion
524529
"Only row_major or column_major is allowed");
525530
const bool isTransposeRequired = valueRowMajor ^ memoryRowMajor;
526531

527-
auto dpasLayout = cast<DpasEncodingAttr>(dotLayout.getParent());
532+
auto getOpIdx = [&]() -> DpasEncodingAttr::OpIdx {
533+
if (hasDpasLayout) {
534+
return DpasEncodingAttr::OpIdx::OperandC;
535+
} else {
536+
auto dotLayout = getDotEncoding(tensorType).value();
537+
return static_cast<DpasEncodingAttr::OpIdx>(dotLayout.getOpIdx());
538+
}
539+
};
528540

529-
auto opIdx = static_cast<DpasEncodingAttr::OpIdx>(dotLayout.getOpIdx());
541+
auto opIdx = getOpIdx();
530542
Type eltTy = tensorType.getElementType();
543+
unsigned elemSizeInBits = eltTy.getIntOrFloatBitWidth();
544+
531545
const ArrayRef<int64_t> tensorShape = tensorType.getShape();
532546
unsigned numElems = getTotalElemsPerThread(resultType);
533547
SmallVector<int64_t> numReps =
@@ -543,6 +557,123 @@ struct LoadOpConversion
543557
SmallVector<Value> multiDimWarpId =
544558
delinearize(rewriter, loc, warpId, warpsPerCTA, dpasOrder);
545559

560+
if (hasDpasLayout) {
561+
// A block load with the DPAS layout but without the DotDpasLayout is
562+
// expected to follow the ordering of the DPAS output. For a 2D block
563+
// load, the rows are distributed across work items/SIMD lanes and the
564+
// column vectors are available for each work item to process. This layout
565+
// aligns to the DPAS layout as the DPAS operation output layout
566+
// distributes rows across work items.
567+
if (isTransposeRequired) {
568+
// TODO: this would likely require a shuffle to match the expected
569+
// ordering coming out of the DPAS layout and requires more
570+
// investigation
571+
return failure();
572+
}
573+
574+
MLIRContext *ctx = rewriter.getContext();
575+
576+
Value elemSizeInBytes = i32_val(elemSizeInBits / 8);
577+
578+
SmallVector<unsigned> elemsPerInstr = dpasLayout.getDPASInstShapeC();
579+
int64_t elemsPerLane = product<unsigned>(elemsPerInstr) / threadsPerWarp;
580+
Type load2DGenXType =
581+
LLVM::getFixedVectorType(IntegerType::get(ctx, elemSizeInBits),
582+
elemsPerLane); // make it opaque type.
583+
584+
auto [base, baseWidth, baseHeight, rowStride, colStride, offsetBaseX,
585+
offsetBaseY] =
586+
getValuesFromBlockPointerStruct(adaptor.getPtr(), rewriter);
587+
baseWidth = trunc(i32_ty, baseWidth);
588+
baseHeight = trunc(i32_ty, baseHeight);
589+
590+
auto pitch = trunc(i32_ty, rowStride);
591+
592+
SmallVector<unsigned> repClusterShape = dpasLayout.getShapeC();
593+
unsigned outerDimWarpNum =
594+
std::min<unsigned>(warpsPerCTA[rank - 2],
595+
mlir::ceil<unsigned>(tensorShape[rank - 2],
596+
repClusterShape[rank - 2]));
597+
unsigned innerDimWarpNum =
598+
std::min<unsigned>(warpsPerCTA[rank - 1],
599+
mlir::ceil<unsigned>(tensorShape[rank - 1],
600+
repClusterShape[rank - 1]));
601+
Value outerDimWarpId =
602+
urem(multiDimWarpId[rank - 2], i32_val(outerDimWarpNum));
603+
Value innerDimWarpId =
604+
urem(multiDimWarpId[rank - 1], i32_val(innerDimWarpNum));
605+
int64_t numRepOuter = numReps[1];
606+
int64_t numRepInner = numReps[2];
607+
608+
std::array<unsigned, 2> replicaStride = {
609+
outerDimWarpNum * repClusterShape[rank - 2],
610+
innerDimWarpNum * repClusterShape[rank - 1]};
611+
std::array<unsigned, 2> warpStride = {repClusterShape[rank - 2],
612+
repClusterShape[rank - 1]};
613+
614+
Value dimWarpId0 = mul(outerDimWarpId, i32_val(warpStride[0]));
615+
Value dimWarpId1 = mul(innerDimWarpId, i32_val(warpStride[1]));
616+
Value warpId0Offset = add(dimWarpId0, offsetBaseY);
617+
Value warpId1Offset = add(dimWarpId1, offsetBaseX);
618+
619+
ArrayRef<unsigned> repCluster = dpasLayout.getRepCluster();
620+
unsigned valOffset = 0;
621+
622+
SmallVector<Value> unpackedLoadedVals;
623+
624+
for (int m = 0; m < numRepOuter; ++m) {
625+
for (int n = 0; n < numRepInner; ++n) {
626+
for (int repM = 0; repM < repCluster[0]; ++repM) {
627+
628+
Value offsetY =
629+
add(warpId0Offset,
630+
i32_val(m * replicaStride[0] + repM * elemsPerInstr[0]));
631+
for (int repN = 0; repN < repCluster[1]; ++repN) {
632+
Value offsetX =
633+
add(warpId1Offset,
634+
i32_val(n * replicaStride[1] + repN * elemsPerInstr[1]));
635+
636+
auto load2dOp = rewriter.create<TritonGEN::Matrix2DBlockLoadOp>(
637+
loc, load2DGenXType,
638+
/*ptr*/ base,
639+
/*base_width*/ mul(baseWidth, elemSizeInBytes),
640+
/*base_height*/ baseHeight,
641+
/*base_pitch*/ mul(pitch, elemSizeInBytes),
642+
/*x*/ trunc(i32_ty, offsetX),
643+
/*y*/ trunc(i32_ty, offsetY),
644+
/*elem_size_in_bits*/ elemSizeInBits,
645+
/*tile_width*/ elemsPerInstr[1],
646+
/*tile_height*/ elemsPerInstr[0],
647+
/*v_blocks*/ 1,
648+
/*transpose*/ false,
649+
/*vnni_transform*/ false);
650+
if (failed(load2dOp.verify())) {
651+
// Explicitly invoke verifier because `triton_gen` ops are
652+
// immediately lowered further to a builtin call.
653+
return failure();
654+
}
655+
656+
Value ret = bitcast(
657+
load2dOp, LLVM::getFixedVectorType(eltTy, elemsPerLane));
658+
659+
for (size_t i = 0; i < elemsPerLane; i++) {
660+
Value loaded = extract_element(eltTy, ret, i32_val(i));
661+
unpackedLoadedVals.push_back(loaded);
662+
}
663+
}
664+
}
665+
}
666+
}
667+
668+
TritonGPUToLLVMTypeConverter *typeConverter = getTypeConverter();
669+
Type llvmResultStructTy = typeConverter->convertType(op.getType());
670+
Value resultStruct = packLLElements(
671+
loc, typeConverter, unpackedLoadedVals, rewriter, llvmResultStructTy);
672+
rewriter.replaceOp(op, {resultStruct});
673+
674+
return success();
675+
}
676+
546677
bool isOperandA = (opIdx == DpasEncodingAttr::OpIdx::OperandA);
547678
SmallVector<unsigned> dpasInstShape = isOperandA
548679
? dpasLayout.getDPASInstShapeA()
@@ -573,11 +704,11 @@ struct LoadOpConversion
573704
// input operands to DPAS.
574705
// TODO: add support for int4 and int2.
575706
unsigned opsPerChannel = dpasLayout.getOpsPerChannel();
576-
unsigned elemBits = eltTy.getIntOrFloatBitWidth();
577-
if ((opsPerChannel == 4 && elemBits == 8) ||
578-
(opsPerChannel == 2 && elemBits == 16) ||
579-
(opsPerChannel == 1 && elemBits == 32)) {
580-
loadResultElemType = (isOperandA && elemBits != 32) ? i16_ty : i32_ty;
707+
if ((opsPerChannel == 4 && elemSizeInBits == 8) ||
708+
(opsPerChannel == 2 && elemSizeInBits == 16) ||
709+
(opsPerChannel == 1 && elemSizeInBits == 32)) {
710+
loadResultElemType =
711+
(isOperandA && elemSizeInBits != 32) ? i16_ty : i32_ty;
581712
packedElemsPerLanePerDPASInst =
582713
isOperandA ? elemsPerLanePerDPASInst / (opsPerChannel == 4 ? 2 : 1)
583714
: elemsPerLanePerDPASInst / opsPerChannel;
@@ -651,7 +782,7 @@ struct LoadOpConversion
651782

652783
// PVC 2D load supports 64 bytes per row at most. Load multiple dot operands
653784
// by enlarging the vBlocks.
654-
unsigned totalBytesPerRowPerDPASOp = tileWidth * elemBits / 8;
785+
unsigned totalBytesPerRowPerDPASOp = tileWidth * elemSizeInBits / 8;
655786
numOperandsPer2DloadN =
656787
std::min(numOperandsPer2DloadN, 64 / totalBytesPerRowPerDPASOp);
657788
vBlocks = numOperandsPer2DloadN;
@@ -695,12 +826,12 @@ struct LoadOpConversion
695826
baseWidth = trunc(i32_ty, baseWidth);
696827
baseHeight = trunc(i32_ty, baseHeight);
697828

698-
unsigned originalElemBits = elemBits;
829+
const unsigned originalElemBits = elemSizeInBits;
699830
if (isTransposeRequired) {
700831
// adjust the block io parameter to align HW's limitations on
701832
// transposing load.
702833
tileWidth = tileWidth / (32 / originalElemBits);
703-
elemBits = 32;
834+
elemSizeInBits = 32;
704835
}
705836
Value elemSizeInBytes = i32_val(originalElemBits / 8);
706837

@@ -747,14 +878,14 @@ struct LoadOpConversion
747878
/*base_pitch*/ mul(pitch, elemSizeInBytes),
748879
/*x*/ trunc(i32_ty, offsetX),
749880
/*y*/ trunc(i32_ty, offsetY),
750-
/*elem_size_in_bits*/ elemBits,
881+
/*elem_size_in_bits*/ elemSizeInBits,
751882
/*tile_width*/ tileWidth,
752883
/*tile_height*/ tileHeight,
753884
/*v_blocks*/ vBlocks,
754885
/*transpose*/ isTransposeRequired,
755886
/*vnni_transform*/
756887
(usePackedType && !isOperandA && !isTransposeRequired &&
757-
eltTy.getIntOrFloatBitWidth() != 32));
888+
originalElemBits != 32));
758889
if (failed(load2dOp.verify())) {
759890
// Explicitly invoke verifier because `triton_gen` ops are
760891
// immediately lowered further to a builtin call.

0 commit comments

Comments
 (0)