Skip to content
This repository was archived by the owner on Aug 30, 2024. It is now read-only.

Commit 71c2c1d

Browse files
committed
replace base_ptr/width/height/surface_pitch
1 parent a0f3194 commit 71c2c1d

File tree

2 files changed: 14 additions (+14) and 17 deletions (-17).

include/subgroup/tile/impl/load_xe.hpp

Lines changed: 9 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -180,6 +180,7 @@ tile_load(tile_t& tile, payload_t& payload) {
180180
#pragma unroll
181181
for (uint32_t i = 0; i < num_block_y; ++i) {
182182
constexpr uint32_t load_block_elems = block_elems * arr_len;
183+
int offset_y = i * block_size_y;
183184
auto payload_row =
184185
payload_2d.xetla_select<num_block_x, 1, 16, 1>(i * num_block_x, 0);
185186
detail::reset_tile_desc_core<
@@ -191,6 +192,7 @@ tile_load(tile_t& tile, payload_t& payload) {
191192
mem_transpose>(payload_row);
192193
#pragma unroll
193194
for (uint32_t j = 0; j < num_block_x; j += arr_len) {
195+
uint32_t offset_x = j * block_size_x;
194196
xetla_tdescriptor tdesc = payload_row.row(j);
195197
auto reg_blk = tile.reg.xetla_select<load_block_elems, 1>(
196198
(i * num_block_x + j) * block_elems);
@@ -201,6 +203,7 @@ tile_load(tile_t& tile, payload_t& payload) {
201203
xetla_vector<dtype, tmp_size> reg_tmp;
202204
#pragma unroll
203205
for (uint32_t ii = 0; ii < block_size_y / ld_blk_size_y; ++ii) {
206+
// offset_y += ld_blk_size_y;
204207
constexpr uint32_t load_elems = ld_blk_size_y * block_size_x * arr_len;
205208
reg_tmp.xetla_format<native_type_t<load_dtype>>() = xetla_load_global<
206209
native_type_t<load_dtype>,
@@ -217,16 +220,12 @@ tile_load(tile_t& tile, payload_t& payload) {
217220
payload.surface_width,
218221
payload.surface_height,
219222
payload.surface_pitch,
220-
// payload.offset_x,
221-
// payload.offset_y);
222-
223-
// (native_type_t<load_dtype>*)::gpu::xetla::detail::
224-
// xetla_get_tensor_base_address(tdesc),
225-
// ::gpu::xetla::detail::xetla_get_tensor_width_x(tdesc),
226-
// ::gpu::xetla::detail::xetla_get_tensor_width_y(tdesc),
227-
// ::gpu::xetla::detail::xetla_get_tensor_pitch_x(tdesc),
228-
::gpu::xetla::detail::xetla_get_tensor_offset_x(tdesc),
229-
::gpu::xetla::detail::xetla_get_tensor_offset_y(tdesc));
223+
mem_transpose
224+
// ? (payload.offset_x + offset_y / scale_factor)
225+
? ::gpu::xetla::detail::xetla_get_tensor_offset_x(tdesc)
226+
: (payload.offset_x + offset_x / scale_factor),
227+
228+
payload.offset_y + (mem_transpose ? offset_x : offset_y));
230229
if constexpr (reg_transpose && trans) {
231230
reg_blk.xetla_select<load_elems, 1>(ii * load_elems)
232231
.xetla_format<native_type_t<load_dtype>>() =
@@ -243,7 +242,6 @@ tile_load(tile_t& tile, payload_t& payload) {
243242
} else {
244243
reg_blk.xetla_select<tmp_size, 1>(ii * tmp_size) = reg_tmp;
245244
}
246-
247245
if constexpr (mem_transpose) {
248246
xetla_update_tdesc_offsetx(
249247
tdesc.xetla_format<uint32_t>(), ld_blk_size_y / scale_factor);

include/subgroup/tile/impl/payload_xe.hpp

Lines changed: 5 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -111,10 +111,8 @@ struct mem_payload_t<
111111
this->surface_height =
112112
(mem_transpose ? mem_desc.shape.x : mem_desc.shape.y);
113113
this->surface_pitch = mem_desc.shape.stride * sizeof(dtype);
114-
// this->offset_x = mem_desc.coord.x;
115-
// this->offset_y = mem_desc.coord.y;
116-
this->offset_x = mem_transpose ? mem_desc.coord.y : mem_desc.coord.x;
117-
this->offset_x = this->offset_x / scale_factor;
114+
this->offset_x =
115+
(mem_transpose ? mem_desc.coord.y : mem_desc.coord.x) / scale_factor;
118116
this->offset_y = mem_transpose ? mem_desc.coord.x : mem_desc.coord.y;
119117

120118
xetla_tdescriptor base_tdesc = mem_desc.get_tdesc();
@@ -159,8 +157,9 @@ struct mem_payload_t<
159157
this->surface_height =
160158
(mem_transpose ? mem_desc.shape.x : mem_desc.shape.y);
161159
this->surface_pitch = mem_desc.shape.stride * sizeof(dtype);
162-
this->offset_x = mem_desc.coord.x / scale_factor;
163-
this->offset_y = mem_desc.coord.y;
160+
this->offset_x =
161+
(mem_transpose ? mem_desc.coord.y : mem_desc.coord.x) / scale_factor;
162+
this->offset_y = (mem_transpose ? mem_desc.coord.x : mem_desc.coord.y);
164163

165164
xetla_tdescriptor base_tdesc = mem_desc.get_tdesc();
166165
int32_t offset = gpu::xetla::detail::xetla_get_tensor_offset_x(base_tdesc) /

0 commit comments

Comments
 (0)