
Commit 318ed76

[LoadStoreOpToLLVM] Transpose 2d load. (#4870)
Use the transposed 2D block IO to load a column-major matrix from global memory. (The column-major case generalizes to any case where the fast-changing dimension of the register layout is not the fast-changing dimension in global memory.) The 2D block IO can only transpose matrices of i32 type when loading from memory to registers, so for element types narrower than 32 bits we must additionally transpose the matrix inside the registers. The steps to load a matrix with transposition through 2D block IO are:

1. Load the matrix from memory to registers as a d32 matrix, with the transpose flag set.
2. (Only needed when the scalar type is narrower than 32 bits) Transpose the `MxNxd32` result to `Mx(32/m)xNxdm` inside the registers.

For step 2 we currently use a single bitcast operation, which only handles matrices whose width equals the threads per warp; support for matrices whose width differs from the threads per warp will come later.

Signed-off-by: Lu,Chengjun <chengjun.lu@intel.com>
1 parent 4e7fda2 commit 318ed76
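The following is a minimal numpy model of the two steps described in the commit message. It is an illustration, not the actual lowering code: the shapes, variable names, and the little-endian packing are assumptions of mine.

import numpy as np

# Step-by-step model of the transposed load for a 16-bit element type
# (m = 16, so 32 / m = 2 values are packed per d32 lane).
M, N = 16, 16
a = (np.arange(M * N) % 251).astype(np.uint16).reshape(M, N)
col_major_mem = np.ascontiguousarray(a.T)    # the matrix as it sits in global memory

# Step 1: the 2D block IO loads d32 lanes with transpose = true. Model this by
# viewing memory as i32 (two u16 packed along the fast dim) and transposing.
mem_i32 = col_major_mem.view(np.uint32)      # shape (N, M // 2)
loaded_i32 = mem_i32.T.copy()                # shape (M // 2, N): the transposed i32 tile

# Step 2 (scalar type < 32 bits): bitcast each i32 lane back into two u16
# values, i.e. the in-register transpose of MxNxd32 to Mx(32/m)xNxdm.
unpacked = loaded_i32.view(np.uint16).reshape(M // 2, N, 2)
result = unpacked.transpose(0, 2, 1).reshape(M, N)

assert np.array_equal(result, a)             # registers now hold the row-major tile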

File tree

3 files changed, +161 -658 lines


python/test/unit/intel/test_block_io.py (21 additions, 6 deletions)

@@ -120,8 +120,9 @@ def warps_per_cta(layout):
 @pytest.mark.parametrize("layout", layouts)
 @pytest.mark.parametrize("load_block_ptr, store_block_ptr", [(True, True), (False, False), (True, False),
                                                              (False, True)])
+@pytest.mark.parametrize("transpose", [True, False])
 @pytest.mark.skipif(not is_xpu(), reason="Block store tests are specific to the XPU backend")
-def test_block_io(M, N, dtype_str, layout, load_block_ptr, store_block_ptr, device, tmp_path: pathlib.Path):
+def test_block_io(M, N, dtype_str, layout, load_block_ptr, store_block_ptr, transpose, device, tmp_path: pathlib.Path):
 
     warps = warps_per_cta(layout)
     num_warps = int(np.prod(warps))
@@ -132,16 +133,20 @@ def test_block_io(M, N, dtype_str, layout, load_block_ptr, store_block_ptr, devi
 
     support_block_io = torch.xpu.get_device_capability()['has_subgroup_2d_block_io']
 
+    block_io = "\"column_major\"" if transpose else "\"row_major\""
+
+    strides = "[%c1_i64, %M_i64]" if transpose else "[%N_i64, %c1_i64]"
+
     if load_block_ptr:
         load_ops = f"""
-            %src_ptr = tt.make_tensor_ptr %src, [%M_i64, %N_i64], [%N_i64, %c1_i64], [%c0_i32, %c0_i32] {{order = array<i32: 1, 0>}} : <tensor<{M}x{N}x{ty}, #layout>>
-            %store_val = tt.load %src_ptr {{ttig.block_io = "row_major", boundaryCheck = array<i32: 0, 1>, padding = 1 : i32}} : !tt.ptr<tensor<{M}x{N}x{ty}, #layout>>
+            %src_ptr = tt.make_tensor_ptr %src, [%M_i64, %N_i64], {strides}, [%c0_i32, %c0_i32] {{order = array<i32: 1, 0>}} : <tensor<{M}x{N}x{ty}, #layout>>
+            %store_val = tt.load %src_ptr {{ttig.block_io = {block_io}, boundaryCheck = array<i32: 0, 1>, padding = 1 : i32}} : !tt.ptr<tensor<{M}x{N}x{ty}, #layout>>
         """
     else:
         load_ops = f"""
             %src_base = tt.splat %src : !tt.ptr<{ty}> -> tensor<{M}x{N}x!tt.ptr<{ty}>, #layout>
-            %src_ptr = tt.addptr %src_base, %row_major_off : tensor<{M}x{N}x!tt.ptr<{ty}>, #layout>, tensor<{M}x{N}xi32, #layout>
-            %store_val = tt.load %src_ptr {{ttig.block_io = "row_major"}} : tensor<{M}x{N}x!tt.ptr<{ty}>, #layout>
+            %src_ptr = tt.addptr %src_base, {"%col_major_off" if transpose else "%row_major_off" } : tensor<{M}x{N}x!tt.ptr<{ty}>, #layout>, tensor<{M}x{N}xi32, #layout>
+            %store_val = tt.load %src_ptr {{ttig.block_io = {block_io}}} : tensor<{M}x{N}x!tt.ptr<{ty}>, #layout>
         """
     if store_block_ptr:
         store_ops = f"""
@@ -175,6 +180,12 @@ def test_block_io(M, N, dtype_str, layout, load_block_ptr, store_block_ptr, devi
        %7 = tt.broadcast %5 : tensor<1x{N}xi32, #layout> -> tensor<{M}x{N}xi32, #layout>
        %row_major_off = arith.addi %6, %7 : tensor<{M}x{N}xi32, #layout>
 
+       %stride_M = arith.constant dense<{M}> : tensor<1x{N}xi32, #layout>
+       %col_stride = arith.muli %5, %stride_M : tensor<1x{N}xi32, #layout>
+       %8 = tt.broadcast %2 : tensor<{M}x1xi32, #layout> -> tensor<{M}x{N}xi32, #layout>
+       %9 = tt.broadcast %col_stride : tensor<1x{N}xi32, #layout> -> tensor<{M}x{N}xi32, #layout>
+       %col_major_off = arith.addi %8, %9 : tensor<{M}x{N}xi32, #layout>
+
        {load_ops}
        {store_ops}
 
@@ -195,10 +206,14 @@ def test_block_io(M, N, dtype_str, layout, load_block_ptr, store_block_ptr, devi
     temp_file.write_text(ir)
     kernel = triton.compile(str(temp_file))
 
+    a = a.permute(1, 0).contiguous().permute(1, 0) if transpose else a
+
     kernel[(1, 1, 1)](a, x)
     assert torch.equal(a, x)
 
     if support_block_io:
         if not load_block_ptr:
-            assert 'spirv_Subgroup2DBlockLoad' in kernel.asm['llir'] or 'GenISA.LSC2DBlockRead' in kernel.asm['llir']
+            if not ((transpose and type(layout) in [SliceLayout]) or
+                    (transpose and dtype_str in ["float16", "int8"])):  # TODO: add support for these cases
+                assert 'spirv_Subgroup2DBlockLoad' in kernel.asm['llir'] or 'GenISA.LSC2DBlockRead' in kernel.asm['llir']
         assert 'spirv_Subgroup2DBlockStoreINTEL' in kernel.asm['llir'] or 'GenISA.LSC2DBlockWrite' in kernel.asm['llir']
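As a side note on the `%col_major_off` computation added above: it builds the standard linearized offsets for an M x N tile. A hypothetical numpy sketch of the same arithmetic (index names are mine; `%2` and `%6` come from context lines not shown in this hunk):

import numpy as np

# Element (i, j) of an M x N tile sits at offset i * N + j in a row-major
# buffer and at offset i + j * M in a column-major one.
M, N = 4, 8
i = np.arange(M).reshape(M, 1)   # row index, an M x 1 column like %2
j = np.arange(N).reshape(1, N)   # column index, a 1 x N row like %5

row_major_off = i * N + j        # broadcasts to an M x N offset tensor (%row_major_off)
col_major_off = i + j * M        # %col_major_off = %8 + %9: stride M along the column index

buf = np.arange(M * N)           # flat stand-in for global memory
assert np.array_equal(buf[row_major_off], buf.reshape(M, N))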

test/TritonIntelGPU/tensor-pointer-load-block-2d.mlir (86 additions, 1 deletion)

@@ -1,4 +1,4 @@
-// RUN: triton-opt %s -split-input-file --intel-allocate-shared-memory --convert-triton-intel-gpu-to-llvm | FileCheck %s --implicit-check-not=llvm.inline_asm
+// RUN: env TRITON_INTEL_ENABLE_BLOCK_IO_ALL_LAYOUTS=1 triton-opt %s -split-input-file --intel-allocate-shared-memory --convert-triton-intel-gpu-to-llvm | FileCheck %s --implicit-check-not=llvm.inline_asm
 
 #mma = #ttig.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [2, 4], repCluster = [4, 2], A = [32, 16], B = [16, 32], C = [32, 32]}>
 module attributes {ttig.min_sg_size = 16 : i32, ttig.support_bf16_conversion, ttig.support_dpas, ttig.support_sg_2d_block, ttig.target_arch = "spir64", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.shared = 33280 : i32, ttg.target = "xpu", "ttg.threads-per-warp" = 16 : i32} {
@@ -566,3 +566,88 @@ module attributes {ttig.support_sg_2d_block, "ttg.num-warps" = 8 : i32, "ttg.thr
     tt.return
   }
 }
+
+// -----
+
+#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [2, 16], warpsPerCTA = [4, 1], order = [1, 0]}>
+module attributes {ttg.global_scratch_memory_alignment = 1 : i32, ttg.global_scratch_memory_size = 0 : i32, "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.shared = 0 : i32, ttg.target = "xpu", "ttg.threads-per-warp" = 32 : i32, ttig.support_sg_2d_block} {
+  tt.func public @trans_block_load_i32(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f32> {tt.divisibility = 16 : i32}) attributes {ttg.global_scratch_memory_alignment = 1 : i32, ttg.global_scratch_memory_size = 0 : i32} {
+    %cst = arith.constant dense<64> : tensor<32x1xi32, #blocked>
+    %0 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #ttg.slice<{dim = 1, parent = #blocked}>>
+    %1 = tt.expand_dims %0 {axis = 1 : i32} : tensor<32xi32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<32x1xi32, #blocked>
+    %3 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked}>>
+    %4 = tt.expand_dims %3 {axis = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked}>> -> tensor<1x64xi32, #blocked>
+    %cst_0 = arith.constant dense<32> : tensor<1x64xi32, #blocked>
+    %8 = arith.muli %4, %cst_0 : tensor<1x64xi32, #blocked>
+    %9 = tt.broadcast %1 : tensor<32x1xi32, #blocked> -> tensor<32x64xi32, #blocked>
+    %10 = tt.broadcast %8 : tensor<1x64xi32, #blocked> -> tensor<32x64xi32, #blocked>
+    %11 = arith.addi %9, %10 : tensor<32x64xi32, #blocked>
+    %12 = tt.splat %arg0 : !tt.ptr<f32> -> tensor<32x64x!tt.ptr<f32>, #blocked>
+    %13 = tt.addptr %12, %11 : tensor<32x64x!tt.ptr<f32>, #blocked>, tensor<32x64xi32, #blocked>
+    // COM: Transpose 2D block load with i32 type.
+    // CHECK-COUNT-16: triton_gen.2Dblockload {{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}} {elem_size_in_bits = 32, tile_width = 2, tile_height = 16, v_blocks = 1, transpose = true, vnni_transform = false, cache_control = Default} : (!llvm.ptr<1>, i32, i32, i32, i32, i32) -> vector<1xi32>
+    %14 = tt.load %13 {ttig.block_io = "column_major"} : tensor<32x64x!tt.ptr<f32>, #blocked>
+    tt.return
+  }
+}
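For a 32-bit element type, as in @trans_block_load_i32 above, the hardware transpose is already element-exact, so the in-register repacking step is skipped. A one-line numpy analogy (illustrative only; names are mine):

import numpy as np

# Memory holds the tile column-major (row j is logical column j); a single
# transposed i32 block load yields the logical M x N tile, no bitcast needed.
M, N = 32, 64
col_major_mem = np.arange(M * N, dtype=np.float32).reshape(N, M)
loaded = col_major_mem.view(np.uint32).T    # models transpose = true on d32 lanes
assert loaded.shape == (M, N)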
+
+// -----
+
+#mma = #ttig.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 4, threadsPerWarp = 16, warpsPerCTA = [1, 4], repCluster = [1, 2], A = [8, 32], B = [32, 32], C = [8, 32]}>
+module attributes {ttg.global_scratch_memory_alignment = 1 : i32, ttg.global_scratch_memory_size = 0 : i32, "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.shared = 0 : i32, ttg.target = "xpu", "ttg.threads-per-warp" = 16 : i32, ttig.support_sg_2d_block} {
+  tt.func public @trans_block_load_i16(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f16> {tt.divisibility = 16 : i32}) attributes {ttg.global_scratch_memory_alignment = 1 : i32, ttg.global_scratch_memory_size = 0 : i32} {
+    %cst = arith.constant dense<64> : tensor<32x1xi32, #mma>
+    %0 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #ttg.slice<{dim = 1, parent = #mma}>>
+    %1 = tt.expand_dims %0 {axis = 1 : i32} : tensor<32xi32, #ttg.slice<{dim = 1, parent = #mma}>> -> tensor<32x1xi32, #mma>
+    %3 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #mma}>>
+    %4 = tt.expand_dims %3 {axis = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #mma}>> -> tensor<1x64xi32, #mma>
+    %cst_0 = arith.constant dense<32> : tensor<1x64xi32, #mma>
+    %8 = arith.muli %4, %cst_0 : tensor<1x64xi32, #mma>
+    %9 = tt.broadcast %1 : tensor<32x1xi32, #mma> -> tensor<32x64xi32, #mma>
+    %10 = tt.broadcast %8 : tensor<1x64xi32, #mma> -> tensor<32x64xi32, #mma>
+    %11 = arith.addi %9, %10 : tensor<32x64xi32, #mma>
+    %12 = tt.splat %arg0 : !tt.ptr<f16> -> tensor<32x64x!tt.ptr<f16>, #mma>
+    %13 = tt.addptr %12, %11 : tensor<32x64x!tt.ptr<f16>, #mma>, tensor<32x64xi32, #mma>
+    // COM: Transpose 2D block load with f16 type. The loaded vector is packed as i32 and then transposed in-register with a bitcast op.
+    // CHECK: %[[LOADED:.*]] = triton_gen.2Dblockload {{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}} {elem_size_in_bits = 32, tile_width = 8, tile_height = 16, v_blocks = 1, transpose = true, vnni_transform = false, cache_control = Default} : (!llvm.ptr<1>, i32, i32, i32, i32, i32) -> vector<8xi32>
+    // CHECK: %[[PACKED_I32:.*]] = llvm.shufflevector %[[LOADED]], %[[LOADED]] [0, 1, 2, 3] : vector<8xi32>
+    // CHECK: llvm.bitcast %[[PACKED_I32]] : vector<4xi32> to vector<8xf16>
+    // CHECK-COUNT-3: triton_gen.2Dblockload {{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}} {elem_size_in_bits = 32, tile_width = 8, tile_height = 16, v_blocks = 1, transpose = true, vnni_transform = false, cache_control = Default} : (!llvm.ptr<1>, i32, i32, i32, i32, i32) -> vector<8xi32>
+    %14 = tt.load %13 {ttig.block_io = "column_major"} : tensor<32x64x!tt.ptr<f16>, #mma>
+    tt.return
+  }
+}
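A per-thread sketch of the f16 repacking that the CHECK lines above describe: the transposed load returns vector<8xi32> per thread, a shuffle selects one 4-lane half, and a bitcast reinterprets those 4 i32 as 8 f16 register values. Names are illustrative and a little-endian target is assumed:

import numpy as np

loaded = np.arange(8, dtype=np.uint32)   # models the vector<8xi32> block-load result
packed_i32 = loaded[[0, 1, 2, 3]]        # llvm.shufflevector %[[LOADED]], %[[LOADED]] [0, 1, 2, 3]
vals = packed_i32.view(np.float16)       # llvm.bitcast vector<4xi32> to vector<8xf16>
assert vals.shape == (8,)                # eight f16 values per thread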
+
+// -----
+
+#mma = #ttig.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 4, threadsPerWarp = 16, warpsPerCTA = [1, 4], repCluster = [1, 2], A = [8, 32], B = [32, 32], C = [8, 32]}>
+module attributes {ttg.global_scratch_memory_alignment = 1 : i32, ttg.global_scratch_memory_size = 0 : i32, "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.shared = 0 : i32, ttg.target = "xpu", "ttg.threads-per-warp" = 16 : i32, ttig.support_sg_2d_block} {
+  tt.func public @trans_block_load_i8(%arg0: !tt.ptr<i8> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<i8> {tt.divisibility = 16 : i32}) attributes {ttg.global_scratch_memory_alignment = 1 : i32, ttg.global_scratch_memory_size = 0 : i32} {
+    %cst = arith.constant dense<128> : tensor<128x1xi32, #mma>
+    %0 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #mma}>>
+    %1 = tt.expand_dims %0 {axis = 1 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #mma}>> -> tensor<128x1xi32, #mma>
+    %2 = arith.muli %1, %cst : tensor<128x1xi32, #mma>
+    %3 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #ttg.slice<{dim = 0, parent = #mma}>>
+    %4 = tt.expand_dims %3 {axis = 0 : i32} : tensor<128xi32, #ttg.slice<{dim = 0, parent = #mma}>> -> tensor<1x128xi32, #mma>
+    %5 = tt.broadcast %2 : tensor<128x1xi32, #mma> -> tensor<128x128xi32, #mma>
+    %6 = tt.broadcast %4 : tensor<1x128xi32, #mma> -> tensor<128x128xi32, #mma>
+    %7 = arith.addi %5, %6 : tensor<128x128xi32, #mma>
+    %cst_0 = arith.constant dense<128> : tensor<1x128xi32, #mma>
+    %8 = arith.muli %4, %cst_0 : tensor<1x128xi32, #mma>
+    %9 = tt.broadcast %1 : tensor<128x1xi32, #mma> -> tensor<128x128xi32, #mma>
+    %10 = tt.broadcast %8 : tensor<1x128xi32, #mma> -> tensor<128x128xi32, #mma>
+    %11 = arith.addi %9, %10 : tensor<128x128xi32, #mma>
+    %12 = tt.splat %arg0 : !tt.ptr<i8> -> tensor<128x128x!tt.ptr<i8>, #mma>
+    %13 = tt.addptr %12, %11 : tensor<128x128x!tt.ptr<i8>, #mma>, tensor<128x128xi32, #mma>
+    // COM: Transpose 2D block load with i8 type. The loaded vector is packed as i32 and then transposed in-register with a bitcast op.
+    // CHECK: %[[LOADED:.*]] = triton_gen.2Dblockload {{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}} {elem_size_in_bits = 32, tile_width = 8, tile_height = 16, v_blocks = 1, transpose = true, vnni_transform = false, cache_control = Default} : (!llvm.ptr<1>, i32, i32, i32, i32, i32) -> vector<8xi32>
+    // COM: We do the shuffle first and then the bitcast. Maybe it would be more efficient to bitcast first, then shuffle?
+    // CHECK: %[[PACKED_1ST_HALF:.*]] = llvm.shufflevector %[[LOADED]], %[[LOADED]] [0, 1] : vector<8xi32>
+    // CHECK: llvm.bitcast %[[PACKED_1ST_HALF]] : vector<2xi32> to vector<8xi8>
+    // CHECK: %[[PACKED_2ND_HALF:.*]] = llvm.shufflevector %[[LOADED]], %[[LOADED]] [2, 3] : vector<8xi32>
+    // CHECK: llvm.bitcast %[[PACKED_2ND_HALF]] : vector<2xi32> to vector<8xi8>
+    // CHECK-COUNT-7: triton_gen.2Dblockload {{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}} {elem_size_in_bits = 32, tile_width = 8, tile_height = 16, v_blocks = 1, transpose = true, vnni_transform = false, cache_control = Default} : (!llvm.ptr<1>, i32, i32, i32, i32, i32) -> vector<8xi32>
+    %14 = tt.load %13 {ttig.block_io = "column_major"} : tensor<128x128x!tt.ptr<i8>, #mma>
+    tt.return
+  }
+}
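The i8 case repacks the same way, two i32 lanes at a time: each vector<2xi32> half of the loaded vector<8xi32> is bitcast to eight i8 values. The numpy sketch below (names mine, little-endian assumed) also speaks to the COM question in the test: bitcasting first and slicing bytes afterwards yields the same values, so the two orders differ only in codegen.

import numpy as np

loaded = np.arange(8, dtype=np.uint32)        # models the vector<8xi32> block-load result
first_half = loaded[[0, 1]].view(np.int8)     # shufflevector [0, 1], then bitcast to vector<8xi8>
second_half = loaded[[2, 3]].view(np.int8)    # shufflevector [2, 3], then bitcast to vector<8xi8>

# Equivalent: bitcast the whole vector first, then take bytes 0..7 and 8..15.
assert np.array_equal(first_half, loaded.view(np.int8)[0:8])
assert np.array_equal(second_half, loaded.view(np.int8)[8:16])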
