Skip to content
This repository was archived by the owner on Aug 30, 2024. It is now read-only.

Commit 9bab674

Browse files
committed
opt load_xe
1 parent f738dc7 commit 9bab674

File tree

3 files changed

+53
-49
lines changed

3 files changed

+53
-49
lines changed

include/subgroup/tile/impl/load_xe.hpp

Lines changed: 25 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -106,19 +106,31 @@ tile_load(tile_t& tile, payload_t& payload) {
106106
static constexpr bool mem_transform = payload_t::mem_transform;
107107

108108
using load_store_attr = load_store_attr_t<msg_type::block_2d, arch_tag>;
109+
110+
// static constexpr uint32_t max_load_width_in_elem = trans
111+
// ? load_store_attr::max_trans_load_width_in_bytes / sizeof(dtype)
112+
// : load_store_attr::max_load_width_in_bytes / sizeof(dtype);
113+
// static constexpr uint32_t max_load_height_in_elem = trans
114+
// ? load_store_attr::max_trans_load_height_in_elem
115+
// : load_store_attr::max_load_height_in_elem;
116+
static constexpr uint32_t max_trans_load_width_in_elem =
117+
load_store_attr::max_trans_load_width_in_bytes / sizeof(dtype);
118+
static constexpr uint32_t max_load_width_in_elem =
119+
load_store_attr::max_load_width_in_bytes / sizeof(dtype);
120+
121+
// static constexpr uint32_t max_trans_load_height_in_elem =
122+
// load_store_attr::max_trans_load_height_in_elem;
123+
static constexpr uint32_t max_load_height_in_elem =
124+
load_store_attr::max_load_height_in_elem;
125+
109126
static constexpr uint32_t elems_per_CL =
110127
load_store_attr::cache_line_size_in_bytes / sizeof(dtype);
128+
111129
static constexpr uint32_t elems_per_reg =
112130
register_bytes_t<arch_tag>::reg_in_bytes / sizeof(dtype);
113-
static constexpr int32_t max_load_block_height =
114-
load_store_attr::max_load_height_in_elem;
115-
static constexpr int32_t max_block_width =
116-
load_store_attr::max_load_width_in_bytes / sizeof(dtype);
117-
static constexpr int32_t max_trans_block_width =
118-
load_store_attr::max_trans_load_width_in_bytes / sizeof(dtype);
119131

120132
static constexpr uint32_t ld_blk_size_y_limit =
121-
mem_transpose ? max_trans_block_width : max_load_block_height;
133+
mem_transpose ? max_trans_load_width_in_elem : max_load_height_in_elem;
122134
static constexpr uint32_t ld_blk_size_y = reg_transpose
123135
? block_size_y
124136
: std::min(ld_blk_size_y_limit, block_size_y);
@@ -150,20 +162,21 @@ tile_load(tile_t& tile, payload_t& payload) {
150162

151163
static_assert(
152164
reg_transpose || mem_transpose ||
153-
(!mem_transpose && (block_size_x * arr_len) <= max_block_width),
165+
(!mem_transpose &&
166+
(block_size_x * arr_len) <= max_load_width_in_elem),
154167
"When reg_transpose was disabled, check 2d block width "
155168
"restriction");
156169
static_assert(
157170
!reg_transpose ||
158171
(!mem_transpose &&
159-
(block_size_x * arr_len) <= max_trans_block_width) ||
160-
(mem_transpose && (block_size_y * arr_len) <= max_block_width),
172+
(block_size_x * arr_len) <= max_trans_load_width_in_elem) ||
173+
(mem_transpose && (block_size_y * arr_len) <= max_load_width_in_elem),
161174
"When reg_transpose was enabled, check 2d block width "
162175
"restriction");
163176
static_assert(
164177
!reg_transpose ||
165-
(!mem_transpose && (block_size_y <= max_load_block_height)) ||
166-
(mem_transpose && (block_size_x) <= max_load_block_height),
178+
(!mem_transpose && (block_size_y <= max_load_height_in_elem)) ||
179+
(mem_transpose && (block_size_x) <= max_load_height_in_elem),
167180
"When reg_transpose was enabled, check 2d block height "
168181
"restriction");
169182
static_assert(

include/subgroup/tile/impl/payload_xe.hpp

Lines changed: 10 additions & 28 deletions
Original file line number | Diff line number | Diff line change
@@ -85,26 +85,6 @@ struct mem_payload_t<
8585
conditional_t<mem_transpose_dtype_less4bytes, uint32_t, dtype>;
8686
static constexpr uint32_t scale_factor = sizeof(mem_dtype) / sizeof(dtype);
8787

88-
using load_store_attr = load_store_attr_t<msg_type::block_2d, arch_tag>;
89-
90-
static constexpr uint32_t max_load_width_in_elem = trans
91-
? load_store_attr::max_trans_load_width_in_bytes / sizeof(dtype)
92-
: load_store_attr::max_load_width_in_bytes / sizeof(dtype);
93-
static constexpr uint32_t max_load_height_in_elem = trans
94-
? load_store_attr::max_trans_load_height_in_elem
95-
: load_store_attr::max_load_height_in_elem;
96-
97-
static constexpr uint32_t max_store_width_in_elem =
98-
load_store_attr::max_store_width_in_bytes / sizeof(dtype);
99-
static constexpr uint32_t max_store_height_in_elem =
100-
load_store_attr::max_store_height_in_elem;
101-
102-
static constexpr uint32_t elems_per_CL =
103-
load_store_attr::cache_line_size_in_bytes / sizeof(dtype);
104-
105-
static constexpr uint32_t elems_per_reg =
106-
register_bytes_t<arch_tag>::reg_in_bytes / sizeof(dtype);
107-
10888
dtype* base_ptr;
10989
uint32_t surface_width;
11090
uint32_t surface_height;
@@ -1732,11 +1712,12 @@ struct prefetch_payload_t<
17321712
reg_layout_>,
17331713
num_coop_sg_,
17341714
arch_tag_,
1735-
std::enable_if_t<(!arch_has_2d_load_store<arch_tag_>)&&(
1736-
((block_size_y_ != 1 || tile_size_y_ != 1) &&
1737-
mem_layout_ == mem_layout::row_major) ||
1738-
((block_size_x_ != 1 || tile_size_x_ != 1) &&
1739-
mem_layout_ == mem_layout::col_major))>> {
1715+
std::enable_if_t<
1716+
(!arch_has_2d_load_store<arch_tag_>) &&
1717+
(((block_size_y_ != 1 || tile_size_y_ != 1) &&
1718+
mem_layout_ == mem_layout::row_major) ||
1719+
((block_size_x_ != 1 || tile_size_x_ != 1) &&
1720+
mem_layout_ == mem_layout::col_major))>> {
17401721
using dtype = native_type_t<dtype_>;
17411722
using mem_desc_t =
17421723
mem_desc_t<dtype_, mem_layout_, mem_space::global, alignment_>;
@@ -1992,9 +1973,10 @@ struct prefetch_payload_t<
19921973
reg_layout_>,
19931974
num_coop_sg_,
19941975
arch_tag_,
1995-
std::enable_if_t<(arch_has_2d_load_store<arch_tag_>)&&(
1996-
((tile_size_y_ != 1) && mem_layout_ == mem_layout::row_major) ||
1997-
((tile_size_x_ != 1) && mem_layout_ == mem_layout::col_major))>> {
1976+
std::enable_if_t<
1977+
(arch_has_2d_load_store<arch_tag_>) &&
1978+
(((tile_size_y_ != 1) && mem_layout_ == mem_layout::row_major) ||
1979+
((tile_size_x_ != 1) && mem_layout_ == mem_layout::col_major))>> {
19981980
using dtype = dtype_;
19991981
using mem_desc_t =
20001982
mem_desc_t<dtype_, mem_layout_, mem_space::global, alignment_>;

include/subgroup/tile/impl/store_xe.hpp

Lines changed: 18 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -99,19 +99,28 @@ tile_store(tile_t& tile, payload_t& payload) {
9999
static constexpr uint32_t num_block_x = tile_desc::num_block_x;
100100
static constexpr uint32_t num_block_y = tile_desc::num_block_y;
101101

102+
static constexpr gpu_arch arch_tag = payload_t::arch_tag;
103+
104+
using load_store_attr = load_store_attr_t<msg_type::block_2d, arch_tag>;
105+
static constexpr uint32_t max_store_width_in_elem =
106+
load_store_attr::max_store_width_in_bytes / sizeof(dtype);
107+
static constexpr uint32_t max_store_height_in_elem =
108+
load_store_attr::max_store_height_in_elem;
109+
110+
static constexpr uint32_t elems_per_CL =
111+
load_store_attr::cache_line_size_in_bytes / sizeof(dtype);
112+
102113
static_assert(
103-
(payload_t::max_store_width_in_elem % block_size_x) == 0,
114+
(max_store_width_in_elem % block_size_x) == 0,
104115
"max_store_width_in_elem should be a multiply of block_size_x.");
105116

106117
static constexpr uint32_t st_blk_size_y =
107-
std::min(block_size_y, payload_t::max_store_height_in_elem);
118+
std::min(block_size_y, max_store_height_in_elem);
108119

109120
// to make sure full CL store
110-
static constexpr uint32_t st_blk_size_x =
111-
((tile_size_x % payload_t::elems_per_CL) == 0)
112-
? payload_t::elems_per_CL
113-
: (((payload_t::elems_per_CL % tile_size_x) == 0) ? tile_size_x
114-
: block_size_x);
121+
static constexpr uint32_t st_blk_size_x = ((tile_size_x % elems_per_CL) == 0)
122+
? elems_per_CL
123+
: (((elems_per_CL % tile_size_x) == 0) ? tile_size_x : block_size_x);
115124

116125
static constexpr uint8_t arr_len_candidate = st_blk_size_x / block_size_x;
117126
static constexpr bool is_valid_arr_len_candidate = (arr_len_candidate == 1) ||
@@ -120,14 +129,13 @@ tile_store(tile_t& tile, payload_t& payload) {
120129
static constexpr uint8_t arr_len =
121130
is_valid_arr_len_candidate ? arr_len_candidate : 1;
122131

123-
constexpr uint32_t store_block_elems = block_elems * arr_len;
124-
constexpr uint32_t store_elems = st_blk_size_y * st_blk_size_x;
125132
#pragma unroll
126133
for (uint32_t i = 0; i < num_block_y; ++i) {
127134
int32_t offset_y = i * block_size_y;
128135
#pragma unroll
129136
for (uint32_t j = 0; j < num_block_x; j += arr_len) {
130137
int32_t offset_x = j * block_size_x;
138+
constexpr uint32_t store_block_elems = block_elems * arr_len;
131139
auto reg_blk = tile.reg.xetla_select<store_block_elems, 1>(
132140
(i * num_block_x + j) * block_elems);
133141
xetla_vector<dtype, store_block_elems> combine_blk;
@@ -150,6 +158,7 @@ tile_store(tile_t& tile, payload_t& payload) {
150158
}
151159
#pragma unroll
152160
for (uint32_t ii = 0; ii < block_size_y; ii += st_blk_size_y) {
161+
constexpr uint32_t store_elems = st_blk_size_y * st_blk_size_x;
153162
auto st_blk =
154163
combine_blk.xetla_select<store_elems, 1>(ii * st_blk_size_x);
155164
xetla_store_global<dtype, st_blk_size_x, st_blk_size_y, L1, L2>(

0 commit comments

Comments
 (0)