@@ -119,8 +119,9 @@ tile_load(tile_t& tile, payload_t& payload) {
119119 static constexpr uint32_t max_load_width_in_elem =
120120 load_store_attr::max_load_width_in_bytes / sizeof (dtype);
121121
122- // static constexpr uint32_t max_trans_load_height_in_elem =
123- // load_store_attr::max_trans_load_height_in_elem;
122+ static constexpr uint32_t max_trans_load_height_in_elem =
123+ load_store_attr::max_trans_load_height_in_elem;
124+
124125 static constexpr uint32_t max_load_height_in_elem =
125126 load_store_attr::max_load_height_in_elem;
126127
@@ -130,11 +131,25 @@ tile_load(tile_t& tile, payload_t& payload) {
130131 static constexpr uint32_t elems_per_reg =
131132 register_bytes_t <arch_tag>::reg_in_bytes / sizeof (dtype);
132133
134+ static constexpr uint32_t max_ld_blk_width_in_elem =
135+ trans ? max_trans_load_width_in_elem : max_load_width_in_elem;
136+
137+ static constexpr uint32_t max_ld_blk_height_in_elem =
138+ trans ? max_trans_load_height_in_elem : max_load_height_in_elem;
139+
140+ static constexpr uint32_t ld_blk_width = std::min (
141+ mem_transpose ? block_size_y : block_size_x, max_ld_blk_width_in_elem);
142+ static constexpr uint32_t ld_blk_height = std::min (
143+ mem_transpose ? block_size_x : block_size_y, max_ld_blk_height_in_elem);
144+
145+ static constexpr uint32_t ld_blk_size_y =
146+ mem_transpose ? ld_blk_width : ld_blk_height;
147+
133148 static constexpr uint32_t ld_blk_size_y_limit =
134149 mem_transpose ? max_trans_load_width_in_elem : max_load_height_in_elem;
135- static constexpr uint32_t ld_blk_size_y = reg_transpose
136- ? block_size_y
137- : std::min (ld_blk_size_y_limit, block_size_y);
150+ // static constexpr uint32_t ld_blk_size_y = reg_transpose
151+ // ? block_size_y
152+ // : std::min(ld_blk_size_y_limit, block_size_y);
138153
139154 // array len is used to make sure memory load is cache line aligned
140155 // disabled while register or memory transpose
@@ -198,10 +213,10 @@ tile_load(tile_t& tile, payload_t& payload) {
198213 constexpr uint32_t load_block_elems = block_elems * arr_len;
199214 auto reg_blk = tile.reg .xetla_select <load_block_elems, 1 >(
200215 (i * num_block_x + j) * block_elems);
201- constexpr uint32_t ld_blk_height = (reg_transpose && trans)
216+ constexpr uint32_t ld_blk_size_y_pad = (reg_transpose && trans)
202217 ? detail::getNextPowerOf2<ld_blk_size_y>()
203218 : ld_blk_size_y;
204- constexpr uint32_t tmp_size = ld_blk_height * block_size_x * arr_len;
219+ constexpr uint32_t tmp_size = ld_blk_width * ld_blk_height * arr_len;
205220 xetla_vector<dtype, tmp_size> reg_tmp;
206221#pragma unroll
207222 for (uint32_t ii = 0 ; ii < block_size_y / ld_blk_size_y; ++ii) {
@@ -213,10 +228,8 @@ tile_load(tile_t& tile, payload_t& payload) {
213228 mem_transpose ? offset_x : (offset_y + ii * ld_blk_size_y);
214229 reg_tmp.xetla_format <native_type_t <load_dtype>>() = xetla_load_global<
215230 native_type_t <load_dtype>,
216- (trans ? ld_blk_size_y : block_size_x) / scale_factor,
217- (trans ? block_size_x : ld_blk_size_y),
218- // block_size_x / scale_factor,
219- // ld_blk_size_y,
231+ ld_blk_width / scale_factor,
232+ ld_blk_height,
220233 arr_len,
221234 trans,
222235 mem_transform,
@@ -261,11 +274,6 @@ tile_load(tile_t& tile, payload_t& payload) {
261274 (mem_transpose ? remained_blk_size_y : block_size_x) / scale_factor;
262275 constexpr uint8_t block_height =
263276 mem_transpose ? block_size_x : remained_blk_size_y;
264- // constexpr uint32_t block_widthx_widthy_arrlen =
265- // (block_width - 1) | ((block_height - 1) << 8);
266- // gpu::xetla::detail::xetla_set_block_widthx_widthy_arrlen(
267- // tdesc.xetla_format<uint32_t>(), block_widthx_widthy_arrlen);
268-
269277 reg_blk.xetla_select <load_elems, 1 >(remained_start)
270278 .xetla_format <native_type_t <load_dtype>>() = xetla_load_global<
271279 native_type_t <load_dtype>,
@@ -283,15 +291,6 @@ tile_load(tile_t& tile, payload_t& payload) {
283291 payload.surface_pitch ,
284292 payload.offset_x + offset_x / scale_factor,
285293 payload.offset_y + offset_y + remained_start_y);
286-
287- // xetla_tload_global<
288- // load_dtype,
289- // (load_elems / scale_factor),
290- // L1,
291- // L2,
292- // trans,
293- // mem_transform,
294- // arch_tag>(tdesc);
295294 }
296295 }
297296 }
@@ -304,24 +303,16 @@ tile_load(tile_t& tile, payload_t& payload) {
304303 (!reg_transpose && (remained_size_y > ld_blk_size_y_limit))
305304 ? ld_blk_size_y_limit
306305 : remained_size_y;
307- // auto payload_row = payload_2d.xetla_select<num_block_x, 1, 16, 1>(
308- // num_block_y * num_block_x, 0);
309- // detail::reset_tile_desc_core<
310- // num_block_x,
311- // block_size_x,
312- // remained_ld_blk_size_y,
313- // scale_factor,
314- // arr_len,
315- // mem_transpose>(payload_row);
306+
316307#pragma unroll
317308 for (uint32_t j = 0 ; j < num_block_x; j += arr_len) {
318309 int32_t offset_x = j * block_size_x;
319310 // xetla_tdescriptor tdesc = payload_row.row(j);
320311 auto reg_blk = tile.reg .xetla_select <remained_block_elems * arr_len, 1 >(
321312 processed_elems + j * remained_block_elems);
322- constexpr uint32_t ld_blk_height = (reg_transpose && trans)
323- ? detail::getNextPowerOf2<remained_ld_blk_size_y>()
324- : remained_ld_blk_size_y;
313+ // constexpr uint32_t ld_blk_height = (reg_transpose && trans)
314+ // ? detail::getNextPowerOf2<remained_ld_blk_size_y>()
315+ // : remained_ld_blk_size_y;
325316 constexpr uint32_t tmp_size = ld_blk_height * block_size_x * arr_len;
326317 xetla_vector<dtype, tmp_size> reg_tmp;
327318#pragma unroll
0 commit comments