@@ -94,7 +94,8 @@ tile_load(tile_t& tile, payload_t& payload) {
9494 static constexpr gpu_arch arch_tag = payload_t ::arch_tag;
9595
9696 static constexpr reg_layout reg_layout_ = tile_desc::register_layout;
97- static constexpr bool is_vnni_reverse = payload_t ::mem_dword_transpose &&
97+ static constexpr bool is_vnni_reverse =
98+ payload_t ::mem_dword_qword_transpose &&
9899 ((reg_layout_ == reg_layout::tiled) ||
99100 (reg_layout_ == reg_layout::transpose_tiled));
100101 static constexpr bool reg_transpose = tile_desc::reg_transpose;
@@ -121,28 +122,27 @@ tile_load(tile_t& tile, payload_t& payload) {
121122 mem_transpose ? max_trans_block_width : max_load_block_height;
122123 static constexpr uint32_t ld_blk_size_y = reg_transpose
123124 ? block_size_y
124- : (block_size_y > ld_blk_size_y_limit ? ld_blk_size_y_limit
125- : block_size_y);
126-
127- // array len is used to make sure memory load is cache line aligned
128- // disabled while register or memory transpose
129- static constexpr uint8_t arr_len_candidate =
130- (reg_transpose ||
131- mem_transpose
132- // block elements should be integer
133- // times of register bytes
134- || ((block_size_y * block_size_x) % elems_per_reg != 0 )
135- // tail blocks also need to meet above condition
136- ||
137- (((tile_size_y % block_size_y) * block_size_x) % elems_per_reg != 0 )) ||
138- (block_size_y > ld_blk_size_y_limit)
139- ? 1
140- : (((tile_size_x % elems_per_CL) == 0 )
141- ? (((elems_per_CL % block_size_x) == 0 )
142- ? elems_per_CL / block_size_x
143- : 1 )
144- : ((tile_size_x < elems_per_CL) ? (tile_size_x / block_size_x)
145- : 1 ));
125+ : std::min (ld_blk_size_y_limit, block_size_y);
126+
127+ // array len is used to make sure memory load is cache line aligned
128+ // disabled while register or memory transpose
129+ static constexpr uint8_t arr_len_candidate =
130+ (reg_transpose ||
131+ mem_transpose
132+ // block elements should be integer
133+ // times of register bytes
134+ || ((block_size_y * block_size_x) % elems_per_reg != 0 )
135+ // tail blocks also need to meet above condition
136+ || (((tile_size_y % block_size_y) * block_size_x) % elems_per_reg !=
137+ 0 )) ||
138+ (block_size_y > ld_blk_size_y_limit)
139+ ? 1
140+ : (((tile_size_x % elems_per_CL) == 0 )
141+ ? (((elems_per_CL % block_size_x) == 0 )
142+ ? elems_per_CL / block_size_x
143+ : 1 )
144+ : ((tile_size_x < elems_per_CL) ? (tile_size_x / block_size_x)
145+ : 1 ));
146146 static constexpr bool is_valid_arr_len_candidate = (arr_len_candidate == 1 ) ||
147147 (arr_len_candidate == 2 ) || (arr_len_candidate == 4 );
148148
@@ -203,14 +203,31 @@ tile_load(tile_t& tile, payload_t& payload) {
203203 for (uint32_t ii = 0 ; ii < block_size_y / ld_blk_size_y; ++ii) {
204204 constexpr uint32_t load_elems = ld_blk_size_y * block_size_x * arr_len;
205205
206- reg_tmp.xetla_format <native_type_t <load_dtype>>() = xetla_tload_global<
206+ // reg_tmp.xetla_format<native_type_t<load_dtype>>() =
207+ // xetla_tload_global<
208+ // load_dtype,
209+ // ld_blk_height * block_size_x * arr_len / scale_factor,
210+ // L1,
211+ // L2,
212+ // trans,
213+ // mem_transform,
214+ // arch_tag>(tdesc);
215+ reg_tmp.xetla_format <native_type_t <load_dtype>>() = xetla_load_global<
207216 load_dtype,
208- ld_blk_height * block_size_x * arr_len / scale_factor,
209- L1 ,
210- L2 ,
217+ block_size_x / scale_factor,
218+ block_size_y ,
219+ num_block ,
211220 trans,
212221 mem_transform,
213- arch_tag>(tdesc);
222+ L1,
223+ L2>(
224+ (load_dtype*)::gpu::xetla::detail::xetla_get_tensor_base_address (
225+ tdesc),
226+ ::gpu::xetla::detail::xetla_get_tensor_width_x (tdesc),
227+ ::gpu::xetla::detail::xetla_get_tensor_width_y(tdesc),
228+ ::gpu::xetla::detail::xetla_get_tensor_pitch_x(tdesc),
229+ ::gpu::xetla::detail::xetla_get_tensor_offset_x(tdesc),
230+ ::gpu::xetla::detail::xetla_get_tensor_offset_y(tdesc));
214231 if constexpr (reg_transpose && trans) {
215232 reg_blk.xetla_select <load_elems, 1 >(ii * load_elems)
216233 .xetla_format <native_type_t <load_dtype>>() =
@@ -256,14 +273,30 @@ tile_load(tile_t& tile, payload_t& payload) {
256273 tdesc.xetla_format <uint32_t >(), block_widthx_widthy_arrlen);
257274
258275 reg_blk.xetla_select <load_elems, 1 >(remained_start)
259- .xetla_format <native_type_t <load_dtype>>() = xetla_tload_global <
276+ .xetla_format <native_type_t <load_dtype>>() = xetla_load_global <
260277 load_dtype,
261- (load_elems / scale_factor) ,
262- L1 ,
263- L2 ,
278+ block_size_x / scale_factor,
279+ block_size_y ,
280+ num_block ,
264281 trans,
265282 mem_transform,
266- arch_tag>(tdesc);
283+ L1,
284+ L2>(
285+ (load_dtype*)::gpu::xetla::detail::xetla_get_tensor_base_address (
286+ tdesc),
287+ ::gpu::xetla::detail::xetla_get_tensor_width_x (tdesc),
288+ ::gpu::xetla::detail::xetla_get_tensor_width_y(tdesc),
289+ ::gpu::xetla::detail::xetla_get_tensor_pitch_x(tdesc),
290+ ::gpu::xetla::detail::xetla_get_tensor_offset_x(tdesc),
291+ ::gpu::xetla::detail::xetla_get_tensor_offset_y(tdesc));
292+ // xetla_tload_global<
293+ // load_dtype,
294+ // (load_elems / scale_factor),
295+ // L1,
296+ // L2,
297+ // trans,
298+ // mem_transform,
299+ // arch_tag>(tdesc);
267300 }
268301 }
269302 }
@@ -301,14 +334,30 @@ tile_load(tile_t& tile, payload_t& payload) {
301334 constexpr uint32_t load_elems =
302335 remained_ld_blk_size_y * block_size_x * arr_len;
303336
304- reg_tmp.xetla_format <native_type_t <load_dtype>>() = xetla_tload_global <
337+ reg_tmp.xetla_format <native_type_t <load_dtype>>() = xetla_load_global <
305338 load_dtype,
306- (ld_blk_height * block_size_x * arr_len / scale_factor) ,
307- L1 ,
308- L2 ,
339+ block_size_x / scale_factor,
340+ block_size_y ,
341+ num_block ,
309342 trans,
310343 mem_transform,
311- arch_tag>(tdesc);
344+ L1,
345+ L2>(
346+ (load_dtype*)::gpu::xetla::detail::xetla_get_tensor_base_address (
347+ tdesc),
348+ ::gpu::xetla::detail::xetla_get_tensor_width_x (tdesc),
349+ ::gpu::xetla::detail::xetla_get_tensor_width_y(tdesc),
350+ ::gpu::xetla::detail::xetla_get_tensor_pitch_x(tdesc),
351+ ::gpu::xetla::detail::xetla_get_tensor_offset_x(tdesc),
352+ ::gpu::xetla::detail::xetla_get_tensor_offset_y(tdesc));
353+ // xetla_tload_global<
354+ // load_dtype,
355+ // (ld_blk_height * block_size_x * arr_len / scale_factor),
356+ // L1,
357+ // L2,
358+ // trans,
359+ // mem_transform,
360+ // arch_tag>(tdesc);
312361
313362 if constexpr (reg_transpose && trans) {
314363 reg_blk.xetla_select <load_elems, 1 >(ii * load_elems)
@@ -352,14 +401,30 @@ tile_load(tile_t& tile, payload_t& payload) {
352401 gpu::xetla::detail::xetla_set_block_widthx_widthy_arrlen (
353402 tdesc.xetla_format <uint32_t >(), block_widthx_widthy_arrlen);
354403 reg_blk.xetla_select <final_load_elems, 1 >(final_start)
355- .xetla_format <native_type_t <load_dtype>>() = xetla_tload_global <
404+ .xetla_format <native_type_t <load_dtype>>() = xetla_load_global <
356405 load_dtype,
357- final_load_elems / scale_factor,
358- L1 ,
359- L2 ,
406+ block_size_x / scale_factor,
407+ block_size_y ,
408+ num_block ,
360409 trans,
361410 mem_transform,
362- arch_tag>(tdesc);
411+ L1,
412+ L2>(
413+ (load_dtype*)::gpu::xetla::detail::xetla_get_tensor_base_address (
414+ tdesc),
415+ ::gpu::xetla::detail::xetla_get_tensor_width_x (tdesc),
416+ ::gpu::xetla::detail::xetla_get_tensor_width_y(tdesc),
417+ ::gpu::xetla::detail::xetla_get_tensor_pitch_x(tdesc),
418+ ::gpu::xetla::detail::xetla_get_tensor_offset_x(tdesc),
419+ ::gpu::xetla::detail::xetla_get_tensor_offset_y(tdesc));
420+ // xetla_tload_global<
421+ // load_dtype,
422+ // final_load_elems / scale_factor,
423+ // L1,
424+ // L2,
425+ // trans,
426+ // mem_transform,
427+ // arch_tag>(tdesc);
363428 }
364429 }
365430 }
0 commit comments