@@ -180,6 +180,7 @@ tile_load(tile_t& tile, payload_t& payload) {
180180#pragma unroll
181181 for (uint32_t i = 0 ; i < num_block_y; ++i) {
182182 constexpr uint32_t load_block_elems = block_elems * arr_len;
183+ int offset_y = i * block_size_y;
183184 auto payload_row =
184185 payload_2d.xetla_select <num_block_x, 1 , 16 , 1 >(i * num_block_x, 0 );
185186 detail::reset_tile_desc_core<
@@ -191,6 +192,7 @@ tile_load(tile_t& tile, payload_t& payload) {
191192 mem_transpose>(payload_row);
192193#pragma unroll
193194 for (uint32_t j = 0 ; j < num_block_x; j += arr_len) {
195+ uint32_t offset_x = j * block_size_x;
194196 xetla_tdescriptor tdesc = payload_row.row (j);
195197 auto reg_blk = tile.reg .xetla_select <load_block_elems, 1 >(
196198 (i * num_block_x + j) * block_elems);
@@ -201,6 +203,7 @@ tile_load(tile_t& tile, payload_t& payload) {
201203 xetla_vector<dtype, tmp_size> reg_tmp;
202204#pragma unroll
203205 for (uint32_t ii = 0 ; ii < block_size_y / ld_blk_size_y; ++ii) {
206+ // offset_y += ld_blk_size_y;
204207 constexpr uint32_t load_elems = ld_blk_size_y * block_size_x * arr_len;
205208 reg_tmp.xetla_format <native_type_t <load_dtype>>() = xetla_load_global<
206209 native_type_t <load_dtype>,
@@ -217,16 +220,12 @@ tile_load(tile_t& tile, payload_t& payload) {
217220 payload.surface_width ,
218221 payload.surface_height ,
219222 payload.surface_pitch ,
220- // payload.offset_x,
221- // payload.offset_y);
222-
223- // (native_type_t<load_dtype>*)::gpu::xetla::detail::
224- // xetla_get_tensor_base_address(tdesc),
225- // ::gpu::xetla::detail::xetla_get_tensor_width_x(tdesc),
226- // ::gpu::xetla::detail::xetla_get_tensor_width_y(tdesc),
227- // ::gpu::xetla::detail::xetla_get_tensor_pitch_x(tdesc),
228- ::gpu::xetla::detail::xetla_get_tensor_offset_x (tdesc),
229- ::gpu::xetla::detail::xetla_get_tensor_offset_y(tdesc));
223+ mem_transpose
224+ // ? (payload.offset_x + offset_y / scale_factor)
225+ ? ::gpu::xetla::detail::xetla_get_tensor_offset_x (tdesc)
226+ : (payload.offset_x + offset_x / scale_factor),
227+
228+ payload.offset_y + (mem_transpose ? offset_x : offset_y));
230229 if constexpr (reg_transpose && trans) {
231230 reg_blk.xetla_select <load_elems, 1 >(ii * load_elems)
232231 .xetla_format <native_type_t <load_dtype>>() =
@@ -243,7 +242,6 @@ tile_load(tile_t& tile, payload_t& payload) {
243242 } else {
244243 reg_blk.xetla_select <tmp_size, 1 >(ii * tmp_size) = reg_tmp;
245244 }
246-
247245 if constexpr (mem_transpose) {
248246 xetla_update_tdesc_offsetx (
249247 tdesc.xetla_format <uint32_t >(), ld_blk_size_y / scale_factor);
0 commit comments