61 changes: 61 additions & 0 deletions test/TritonIntelGPU/blockptr_load.mlir
@@ -28,6 +28,67 @@ module attributes {"ttg.num-warps" = 8 : i32, "ttg.threads-per-warp" = 16 : i32}

// -----

// CHECK-DAG: llvm.func spir_funccc @_Z38intel_sub_group_f16_f16_matrix_mad_k16Dv8_sDv8_iDv8_f(vector<8xi16>, vector<8xi32>, vector<8xf32>) -> vector<8xf32> attributes {convergent, memory_effects = #llvm.memory_effects<other = none, argMem = none, inaccessibleMem = none>, no_unwind, will_return}
// CHECK-DAG: llvm.func spir_funccc @_Z41intel_sub_group_2d_block_read_16b_8r16x2cPU3AS1viiiDv2_iPt(!llvm.ptr<1> {llvm.nonnull, llvm.readonly}, i32, i32, i32, vector<2xi32>, !llvm.ptr {llvm.nonnull, llvm.writeonly}) attributes {no_unwind, will_return}
// CHECK-DAG: llvm.func spir_funccc @_Z52intel_sub_group_2d_block_read_transform_16b_32r16x1cPU3AS1viiiDv2_iPj(!llvm.ptr<1> {llvm.nonnull, llvm.readonly}, i32, i32, i32, vector<2xi32>, !llvm.ptr {llvm.nonnull, llvm.writeonly}) attributes {no_unwind, will_return}
#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 16], warpsPerCTA = [2, 4], order = [1, 0]}>
#dpas = #triton_intel_gpu.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [4, 2], repCluster = [1, 1], A = [8, 16], B = [16, 16], C = [8, 16]}>
#dot0 = #ttg.dot_op<{opIdx = 0, parent = #dpas, kWidth=2}>
#dot1 = #ttg.dot_op<{opIdx = 1, parent = #dpas, kWidth=2}>
module attributes {"ttg.num-warps" = 8 : i32, "ttg.threads-per-warp" = 16 : i32} {
tt.func public @matmul_no_scf_with_add_kernel(%arg0: !tt.ptr<f16>, %arg1: !tt.ptr<f16>, %arg2: !tt.ptr<f32>, %arg3: i64, %arg4: i64, %arg5: i64, %arg6: i64, %arg8: i64) {
%C = arith.constant dense<0.000000e+00> : tensor<64x64xf32, #dpas>
%c0_i32 = arith.constant 0 : i32
%c1_i64 = arith.constant 1 : i64
%ptrA = tt.make_tensor_ptr %arg0, [%arg3, %arg5], [%arg6, %c1_i64], [%c0_i32, %c0_i32] {order = array<i32: 1, 0>} : <tensor<64x32xf16, #dot0>>
%ptrB = tt.make_tensor_ptr %arg1, [%arg5, %arg4], [%arg8, %c1_i64], [%c0_i32, %c0_i32] {order = array<i32: 1, 0>} : <tensor<32x64xf16, #dot1>>
// CHECK-COUNT-2: llvm.call spir_funccc @_Z41intel_sub_group_2d_block_read_16b_8r16x2cPU3AS1viiiDv2_iPt({{.*}}) {{.*}} : (!llvm.ptr<1>, i32, i32, i32, vector<2xi32>, !llvm.ptr) -> ()
// CHECK-COUNT-2: llvm.call spir_funccc @_Z52intel_sub_group_2d_block_read_transform_16b_32r16x1cPU3AS1viiiDv2_iPj({{.*}}) {{.*}} : (!llvm.ptr<1>, i32, i32, i32, vector<2xi32>, !llvm.ptr) -> ()
// CHECK-COUNT-8: llvm.call spir_funccc @_Z38intel_sub_group_f16_f16_matrix_mad_k16Dv8_sDv8_iDv8_f({{.*}}) {{.*}} : (vector<8xi16>, vector<8xi32>, vector<8xf32>) -> vector<8xf32>
%A = tt.load %ptrA {boundaryCheck = array<i32: 1>, padding = 1 : i32, triton_intel_gpu.block_io = "row_major"} : !tt.ptr<tensor<64x32xf16, #dot0>>
%B = tt.load %ptrB {boundaryCheck = array<i32: 0>, padding = 1 : i32, triton_intel_gpu.block_io = "row_major"} : !tt.ptr<tensor<32x64xf16, #dot1>>
%D = tt.dot %A, %B, %C, inputPrecision = tf32 : tensor<64x32xf16, #dot0> * tensor<32x64xf16, #dot1> -> tensor<64x64xf32, #dpas>
%ptrX = tt.make_tensor_ptr %arg2, [%arg3, %arg4], [%arg8, %c1_i64], [%c0_i32, %c0_i32] {order = array<i32: 1, 0>} : <tensor<64x64xf32, #dpas>>
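// The f32 accumulator tile below is loaded directly in the DPAS layout
// (no dot-operand layout). With this configuration (8x16 DPAS C tile,
// repCluster = [1, 1], warpsPerCTA = [4, 2] over a 64x64 tensor), each work
// item is expected to issue 2x2 = 4 32-bit 2D block reads, hence the count
// below.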
// CHECK-COUNT-4: llvm.call spir_funccc @_Z41intel_sub_group_2d_block_read_32b_8r16x1cPU3AS1viiiDv2_iPj({{.*}}) {{.*}} : (!llvm.ptr<1>, i32, i32, i32, vector<2xi32>, !llvm.ptr) -> ()
%X = tt.load %ptrX {boundaryCheck = array<i32: 0, 1>, triton_intel_gpu.block_io = "row_major"} : !tt.ptr<tensor<64x64xf32, #dpas>>
// CHECK-COUNT-32: llvm.fadd {{.*}}, {{.*}}
%0 = arith.addf %D, %X : tensor<64x64xf32, #dpas>
tt.return
}
}

// -----

// CHECK-DAG: llvm.func spir_funccc @_Z38intel_sub_group_f16_f16_matrix_mad_k16Dv8_sDv8_iDv8_f(vector<8xi16>, vector<8xi32>, vector<8xf32>) -> vector<8xf32> attributes {convergent, memory_effects = #llvm.memory_effects<other = none, argMem = none, inaccessibleMem = none>, no_unwind, will_return}
// CHECK-DAG: llvm.func spir_funccc @_Z41intel_sub_group_2d_block_read_16b_8r16x2cPU3AS1viiiDv2_iPt(!llvm.ptr<1> {llvm.nonnull, llvm.readonly}, i32, i32, i32, vector<2xi32>, !llvm.ptr {llvm.nonnull, llvm.writeonly}) attributes {no_unwind, will_return}
// CHECK-DAG: llvm.func spir_funccc @_Z52intel_sub_group_2d_block_read_transform_16b_32r16x1cPU3AS1viiiDv2_iPj(!llvm.ptr<1> {llvm.nonnull, llvm.readonly}, i32, i32, i32, vector<2xi32>, !llvm.ptr {llvm.nonnull, llvm.writeonly}) attributes {no_unwind, will_return}
#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 16], warpsPerCTA = [2, 4], order = [1, 0]}>
#dpas = #triton_intel_gpu.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [4, 2], repCluster = [1, 1], A = [8, 16], B = [16, 16], C = [8, 16]}>
#dot0 = #ttg.dot_op<{opIdx = 0, parent = #dpas, kWidth=2}>
#dot1 = #ttg.dot_op<{opIdx = 1, parent = #dpas, kWidth=2}>
module attributes {"ttg.num-warps" = 8 : i32, "ttg.threads-per-warp" = 16 : i32} {
tt.func public @matmul_no_scf_with_add_transpose_kernel(%arg0: !tt.ptr<f16>, %arg1: !tt.ptr<f16>, %arg2: !tt.ptr<f32>, %arg3: i64, %arg4: i64, %arg5: i64, %arg6: i64, %arg8: i64) {
%C = arith.constant dense<0.000000e+00> : tensor<64x64xf32, #dpas>
%c0_i32 = arith.constant 0 : i32
%c1_i64 = arith.constant 1 : i64
%ptrA = tt.make_tensor_ptr %arg0, [%arg3, %arg5], [%arg6, %c1_i64], [%c0_i32, %c0_i32] {order = array<i32: 1, 0>} : <tensor<64x32xf16, #dot0>>
%ptrB = tt.make_tensor_ptr %arg1, [%arg5, %arg4], [%arg8, %c1_i64], [%c0_i32, %c0_i32] {order = array<i32: 1, 0>} : <tensor<32x64xf16, #dot1>>
// CHECK-COUNT-2: llvm.call spir_funccc @_Z41intel_sub_group_2d_block_read_16b_8r16x2cPU3AS1viiiDv2_iPt({{.*}}) {{.*}} : (!llvm.ptr<1>, i32, i32, i32, vector<2xi32>, !llvm.ptr) -> ()
// CHECK-COUNT-2: llvm.call spir_funccc @_Z52intel_sub_group_2d_block_read_transform_16b_32r16x1cPU3AS1viiiDv2_iPj({{.*}}) {{.*}} : (!llvm.ptr<1>, i32, i32, i32, vector<2xi32>, !llvm.ptr) -> ()
// CHECK-COUNT-8: llvm.call spir_funccc @_Z38intel_sub_group_f16_f16_matrix_mad_k16Dv8_sDv8_iDv8_f({{.*}}) {{.*}} : (vector<8xi16>, vector<8xi32>, vector<8xf32>) -> vector<8xf32>
%A = tt.load %ptrA {boundaryCheck = array<i32: 1>, padding = 1 : i32, triton_intel_gpu.block_io = "row_major"} : !tt.ptr<tensor<64x32xf16, #dot0>>
%B = tt.load %ptrB {boundaryCheck = array<i32: 0>, padding = 1 : i32, triton_intel_gpu.block_io = "row_major"} : !tt.ptr<tensor<32x64xf16, #dot1>>
%D = tt.dot %A, %B, %C, inputPrecision = tf32 : tensor<64x32xf16, #dot0> * tensor<32x64xf16, #dot1> -> tensor<64x64xf32, #dpas>
%ptrX = tt.make_tensor_ptr %arg2, [%arg3, %arg4], [%arg8, %c1_i64], [%c0_i32, %c0_i32] {order = array<i32: 1, 0>} : <tensor<64x64xf32, #dpas>>
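// A column_major (transposed) load of the DPAS-layout tile is not lowered
// to 2D block reads yet (see the TODO in LoadOpConversion), so no 32-bit
// block-read call should appear here.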
// CHECK-NOT: llvm.call spir_funccc @_Z41intel_sub_group_2d_block_read_32b_8r16x1cPU3AS1viiiDv2_iPj({{.*}}) {{.*}} : (!llvm.ptr<1>, i32, i32, i32, vector<2xi32>, !llvm.ptr) -> ()
%X = tt.load %ptrX {boundaryCheck = array<i32: 0, 1>, triton_intel_gpu.block_io = "column_major"} : !tt.ptr<tensor<64x64xf32, #dpas>>
%0 = arith.addf %D, %X : tensor<64x64xf32, #dpas>
tt.return
}
}

// -----

#dpas = #triton_intel_gpu.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 1, threadsPerWarp = 16, warpsPerCTA = [8, 4], repCluster = [1, 1], A = [8, 8], B = [8, 16], C = [8, 16]}>
#dot0 = #ttg.dot_op<{opIdx = 0, parent = #dpas, kWidth=1}>
#dot1 = #ttg.dot_op<{opIdx = 1, parent = #dpas, kWidth=1}>
189 changes: 166 additions & 23 deletions third_party/intel/lib/TritonIntelGPUToLLVM/LoadStoreOpToLLVM.cpp
@@ -499,7 +499,9 @@ struct LoadOpConversion
auto tensorType = cast<RankedTensorType>(resultType);

// Only lower loadOp with dpas layout encoding.
if (!hasDotDpasEncoding(tensorType))
auto encoding = tensorType.getEncoding();
const bool hasDpasLayout = isa<DpasEncodingAttr>(encoding);
if (!hasDpasLayout && !hasDotDpasEncoding(tensorType))
return failure();

Attribute blockIOAttr =
@@ -514,20 +516,24 @@
"Only row_major or column_major is supported");
const bool memoryRowMajor = (memoryLayoutInfo == "row_major");

DotOperandEncodingAttr dotLayout = getDotEncoding(tensorType).value();
auto dotOrder = dotLayout.getThreadOrder();
size_t rank = dotOrder.size();
const bool valueRowMajor =
(dotOrder[rank - 2] == 1 && dotOrder[rank - 1] == 0);
assert((valueRowMajor ||
(dotOrder[rank - 2] == 0 && dotOrder[rank - 1] == 1)) &&
"Only row_major or column_major is allowed");
const bool isTransposeRequired = valueRowMajor ^ memoryRowMajor;

auto dpasLayout = cast<DpasEncodingAttr>(dotLayout.getParent());
auto getOpIdx = [&]() -> DpasEncodingAttr::OpIdx {
if (hasDpasLayout) {
return DpasEncodingAttr::OpIdx::OperandC;
} else {
auto dotLayout = getDotEncoding(tensorType).value();
return static_cast<DpasEncodingAttr::OpIdx>(dotLayout.getOpIdx());
}
};
auto opIdx = getOpIdx();
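// A load producing the plain DPAS (accumulator) layout is treated as
// operand C; loads of dot-operand layouts keep their A/B operand index.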

auto opIdx = static_cast<DpasEncodingAttr::OpIdx>(dotLayout.getOpIdx());
Type eltTy = tensorType.getElementType();
unsigned elemSizeInBits = eltTy.getIntOrFloatBitWidth();

auto dpasLayout = hasDpasLayout
? cast<DpasEncodingAttr>(encoding)
: cast<DpasEncodingAttr>(
getDotEncoding(tensorType).value().getParent());

const ArrayRef<int64_t> tensorShape = tensorType.getShape();
unsigned numElems = getTotalElemsPerThread(resultType);
SmallVector<int64_t> numReps =
@@ -543,6 +549,143 @@ struct LoadOpConversion
SmallVector<Value> multiDimWarpId =
delinearize(rewriter, loc, warpId, warpsPerCTA, dpasOrder);

if (hasDpasLayout) {
// A block load with the DPAS layout (but without a dot-operand layout) is
// expected to follow the ordering of the DPAS output. For a 2D block load,
// the rows are distributed across work items/SIMD lanes and each work item
// holds a column vector of elements to process. This matches the DPAS
// layout, since the DPAS output likewise distributes rows across work
// items.

size_t rank = dpasOrder.size();
const bool valueRowMajor =
(dpasOrder[rank - 2] == 1 && dpasOrder[rank - 1] == 0);
assert((valueRowMajor ||
(dpasOrder[rank - 2] == 0 && dpasOrder[rank - 1] == 1)) &&
"Only row_major or column_major is allowed");
const bool isTransposeRequired = valueRowMajor ^ memoryRowMajor;

if (isTransposeRequired) {
// TODO: this would likely require a shuffle to match the expected
// ordering coming out of the DPAS layout; it needs more investigation.
return failure();
}

MLIRContext *ctx = rewriter.getContext();

Value elemSizeInBytes = i32_val(elemSizeInBits / 8);

SmallVector<unsigned> elemsPerInstr = dpasLayout.getDPASInstShapeC();
int64_t elemsPerLane = product<unsigned>(elemsPerInstr) / threadsPerWarp;
Type load2DGenXType =
LLVM::getFixedVectorType(IntegerType::get(ctx, elemSizeInBits),
elemsPerLane); // Treat the loaded data as an opaque integer vector.
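// For example, with the DPAS configuration in the tests above (C
// instruction shape 8x16, 16 threads per warp), elemsPerLane is
// 8 * 16 / 16 = 8, so each lane receives a vector<8xi32> for f32 data.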

auto [base, baseWidth, baseHeight, rowStride, colStride, offsetBaseX,
offsetBaseY] =
getValuesFromBlockPointerStruct(adaptor.getPtr(), rewriter);
baseWidth = trunc(i32_ty, baseWidth);
baseHeight = trunc(i32_ty, baseHeight);

auto pitch = trunc(i32_ty, rowStride);

SmallVector<unsigned> repClusterShape = dpasLayout.getShapeC();
unsigned outerDimWarpNum =
std::min<unsigned>(warpsPerCTA[rank - 2],
mlir::ceil<unsigned>(tensorShape[rank - 2],
repClusterShape[rank - 2]));
unsigned innerDimWarpNum =
std::min<unsigned>(warpsPerCTA[rank - 1],
mlir::ceil<unsigned>(tensorShape[rank - 1],
repClusterShape[rank - 1]));
Value outerDimWarpId =
urem(multiDimWarpId[rank - 2], i32_val(outerDimWarpNum));
Value innerDimWarpId =
urem(multiDimWarpId[rank - 1], i32_val(innerDimWarpNum));
int64_t numRepOuter = numReps[1];
int64_t numRepInner = numReps[2];

std::array<unsigned, 2> replicaStride = {
outerDimWarpNum * repClusterShape[rank - 2],
innerDimWarpNum * repClusterShape[rank - 1]};
std::array<unsigned, 2> warpStride = {repClusterShape[rank - 2],
repClusterShape[rank - 1]};

Value dimWarpId0 = mul(outerDimWarpId, i32_val(warpStride[0]));
Value dimWarpId1 = mul(innerDimWarpId, i32_val(warpStride[1]));
Value warpId0Offset = add(dimWarpId0, offsetBaseY);
Value warpId1Offset = add(dimWarpId1, offsetBaseX);

ArrayRef<unsigned> repCluster = dpasLayout.getRepCluster();
unsigned valOffset = 0;

SmallVector<Value> unpackedLoadedVals;

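// The loops below unroll one 2D block load per DPAS C tile owned by this
// warp. For the 64x64 f32 tensor in the test above (2x2 repetitions with
// repCluster = [1, 1]), this emits four loads, each yielding 8 scalars per
// lane, i.e. the 32 values consumed by the subsequent fadd.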
for (int m = 0; m < numRepOuter; ++m) {
for (int n = 0; n < numRepInner; ++n) {
for (int repM = 0; repM < repCluster[0]; ++repM) {

Value offsetY =
add(warpId0Offset,
i32_val(m * replicaStride[0] + repM * elemsPerInstr[0]));
for (int repN = 0; repN < repCluster[1]; ++repN) {
Value offsetX =
add(warpId1Offset,
i32_val(n * replicaStride[1] + repN * elemsPerInstr[1]));

auto load2dOp = rewriter.create<TritonGEN::Matrix2DBlockLoadOp>(
loc, load2DGenXType,
/*ptr*/ base,
/*base_width*/ mul(baseWidth, elemSizeInBytes),
/*base_height*/ baseHeight,
/*base_pitch*/ mul(pitch, elemSizeInBytes),
/*x*/ trunc(i32_ty, offsetX),
/*y*/ trunc(i32_ty, offsetY),
/*elem_size_in_bits*/ elemSizeInBits,
/*tile_width*/ elemsPerInstr[1],
/*tile_height*/ elemsPerInstr[0],
/*v_blocks*/ 1,
/*transpose*/ false,
/*vnni_transform*/ false);
if (failed(load2dOp.verify())) {
// Explicitly invoke verifier because `triton_gen` ops are
// immediately lowered further to a builtin call.
return failure();
}

Value ret = bitcast(
load2dOp, LLVM::getFixedVectorType(eltTy, elemsPerLane));

for (size_t i = 0; i < elemsPerLane; i++) {
Value loaded = extract_element(eltTy, ret, i32_val(i));
unpackedLoadedVals.push_back(loaded);
}
}
}
}
}

TritonGPUToLLVMTypeConverter *typeConverter = getTypeConverter();
Type llvmResultStructTy = typeConverter->convertType(op.getType());
Value resultStruct = packLLElements(
loc, typeConverter, unpackedLoadedVals, rewriter, llvmResultStructTy);
rewriter.replaceOp(op, {resultStruct});

return success();
}

DotOperandEncodingAttr dotLayout = getDotEncoding(tensorType).value();
auto dotOrder = dotLayout.getThreadOrder();

size_t rank = dotOrder.size();
const bool valueRowMajor =
(dotOrder[rank - 2] == 1 && dotOrder[rank - 1] == 0);
assert((valueRowMajor ||
(dotOrder[rank - 2] == 0 && dotOrder[rank - 1] == 1)) &&
"Only row_major or column_major is allowed");
const bool isTransposeRequired = valueRowMajor ^ memoryRowMajor;

bool isOperandA = (opIdx == DpasEncodingAttr::OpIdx::OperandA);
SmallVector<unsigned> dpasInstShape = isOperandA
? dpasLayout.getDPASInstShapeA()
@@ -573,11 +716,11 @@ struct LoadOpConversion
// input operands to DPAS.
// TODO: add support for int4 and int2.
unsigned opsPerChannel = dpasLayout.getOpsPerChannel();
unsigned elemBits = eltTy.getIntOrFloatBitWidth();
if ((opsPerChannel == 4 && elemBits == 8) ||
(opsPerChannel == 2 && elemBits == 16) ||
(opsPerChannel == 1 && elemBits == 32)) {
loadResultElemType = (isOperandA && elemBits != 32) ? i16_ty : i32_ty;
if ((opsPerChannel == 4 && elemSizeInBits == 8) ||
(opsPerChannel == 2 && elemSizeInBits == 16) ||
(opsPerChannel == 1 && elemSizeInBits == 32)) {
loadResultElemType =
(isOperandA && elemSizeInBits != 32) ? i16_ty : i32_ty;
packedElemsPerLanePerDPASInst =
isOperandA ? elemsPerLanePerDPASInst / (opsPerChannel == 4 ? 2 : 1)
: elemsPerLanePerDPASInst / opsPerChannel;
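// For example, with f16 operands and opsPerChannel == 2, operand A is
// loaded as i16 elements without packing, while operand B packs two f16
// values into one i32 element, halving packedElemsPerLanePerDPASInst.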
@@ -651,7 +794,7 @@ struct LoadOpConversion

// PVC 2D load supports 64 bytes per row at most. Load multiple dot operands
// by enlarging the vBlocks.
unsigned totalBytesPerRowPerDPASOp = tileWidth * elemBits / 8;
unsigned totalBytesPerRowPerDPASOp = tileWidth * elemSizeInBits / 8;
numOperandsPer2DloadN =
std::min(numOperandsPer2DloadN, 64 / totalBytesPerRowPerDPASOp);
vBlocks = numOperandsPer2DloadN;
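// For example, a 16-column f16 tile occupies 32 bytes per row, so at most
// two operands can be combined into a single 2D load (vBlocks <= 2).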
@@ -695,12 +838,12 @@ struct LoadOpConversion
baseWidth = trunc(i32_ty, baseWidth);
baseHeight = trunc(i32_ty, baseHeight);

unsigned originalElemBits = elemBits;
const unsigned originalElemBits = elemSizeInBits;
if (isTransposeRequired) {
// adjust the block io parameter to align HW's limitations on
// transposing load.
tileWidth = tileWidth / (32 / originalElemBits);
elemBits = 32;
elemSizeInBits = 32;
}
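// For example, a transposed load of 16-bit data is performed as 32-bit
// elements: tileWidth is halved and elemSizeInBits is forced to 32.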
Value elemSizeInBytes = i32_val(originalElemBits / 8);

@@ -747,14 +890,14 @@ struct LoadOpConversion
/*base_pitch*/ mul(pitch, elemSizeInBytes),
/*x*/ trunc(i32_ty, offsetX),
/*y*/ trunc(i32_ty, offsetY),
/*elem_size_in_bits*/ elemBits,
/*elem_size_in_bits*/ elemSizeInBits,
/*tile_width*/ tileWidth,
/*tile_height*/ tileHeight,
/*v_blocks*/ vBlocks,
/*transpose*/ isTransposeRequired,
/*vnni_transform*/
(usePackedType && !isOperandA && !isTransposeRequired &&
eltTy.getIntOrFloatBitWidth() != 32));
originalElemBits != 32));
if (failed(load2dOp.verify())) {
// Explicitly invoke verifier because `triton_gen` ops are
// immediately lowered further to a builtin call.