@@ -79,7 +79,7 @@ template <
     typename payload_t>
 __XETLA_API typename std::enable_if_t<
     detail::check_load_type<tile_t, payload_t>::is_global_block_2d &&
-    payload_t::arch_tag == gpu_arch::XeHpc>
+    arch_has_2d_load_store(payload_t::arch_tag)>
 tile_load(tile_t& tile, payload_t& payload) {
   using dtype = typename tile_t::dtype;
   using load_dtype = typename payload_t::mem_dtype;
@@ -405,23 +405,37 @@ tile_load(tile_t& tile, payload_t& payload) {
 
   static constexpr uint32_t tile_size_x = tile_t::tile_size_x;
   static constexpr uint32_t scale_factor = payload_t::scale_factor;
-  constexpr uint32_t load_len = tile_size_x / scale_factor;
+  static constexpr uint32_t load_len = tile_size_x / scale_factor;
+  static constexpr gpu_arch arch_tag = payload_t::arch_tag;
+  using load_store_attr = load_store_attr_t<msg_type::block_1d, arch_tag>;
+  static constexpr uint32_t max_load_vec_len =
+      load_store_attr::max_load_vec_len;
 
-  if constexpr (load_len >= 64) {
+  static constexpr uint32_t load_iter_steps = load_len / max_load_vec_len;
+  if constexpr (load_len >= max_load_vec_len) {
 #pragma unroll
-    for (uint32_t i = 0; i < load_len / 64; i++) {
-      uint32_t offset_x = i * 64 * scale_factor;
-      auto reg_sub = tile.reg.xetla_select<64 * scale_factor, 1>(offset_x);
+    for (uint32_t i = 0; i < load_iter_steps; i++) {
+      uint32_t offset_x = i * max_load_vec_len * scale_factor;
+      auto reg_sub =
+          tile.reg.xetla_select<max_load_vec_len * scale_factor, 1>(offset_x);
       uint32_t address_offset = offset_x * sizeof(dtype);
-      reg_sub.xetla_format<load_dtype>() =
-          xetla_load_global<load_dtype, 64, data_size::default_size, L1, L2>(
-              payload.base_ptr, payload.base_offset + address_offset);
+      reg_sub.xetla_format<load_dtype>() = xetla_load_global<
+          load_dtype,
+          max_load_vec_len,
+          data_size::default_size,
+          L1,
+          L2>(payload.base_ptr, payload.base_offset + address_offset);
     }
   }
-  constexpr uint32_t tail_len = load_len % 64;
-  uint32_t tail_offset = load_len / 64 * 64 * scale_factor;
-  detail::process_1d_tail<tail_len, 32, detail::process_flag::load, L1, L2>(
-      tile, payload, tail_offset);
+
+  constexpr uint32_t tail_len = load_len % max_load_vec_len;
+  uint32_t tail_offset = load_iter_steps * max_load_vec_len * scale_factor;
+  detail::process_1d_tail<
+      tail_len,
+      (max_load_vec_len >> 1),
+      detail::process_flag::load,
+      L1,
+      L2>(tile, payload, tail_offset);
 }
 
 /// @brief This function loads data from unaligned-2D memory surface.
@@ -850,21 +864,33 @@ tile_load(tile_t& tile, payload_t& payload) {
   using load_dtype = typename payload_t::mem_dtype;
 
   constexpr uint32_t scale_factor = payload_t::scale_factor;
-  constexpr uint32_t load_len = tile_desc::tile_size_x / scale_factor;
-  if constexpr (load_len >= 64) {
+  static constexpr uint32_t load_len = tile_desc::tile_size_x / scale_factor;
+  static constexpr gpu_arch arch_tag = payload_t::arch_tag;
+  using load_store_attr = load_store_attr_t<msg_type::block_1d, arch_tag>;
+  static constexpr uint32_t max_load_vec_len =
+      load_store_attr::max_load_vec_len;
+
+  static constexpr uint32_t load_iter_steps = load_len / max_load_vec_len;
+
+  if constexpr (load_len >= max_load_vec_len) {
 #pragma unroll
-    for (uint32_t j = 0; j < load_len / 64; j++) {
-      uint32_t offset_x = j * 64 * scale_factor;
-      auto reg_sub = tile.reg.xetla_select<64 * scale_factor, 1>(offset_x);
+    for (uint32_t j = 0; j < load_iter_steps; j++) {
+      uint32_t offset_x = j * max_load_vec_len * scale_factor;
+      auto reg_sub =
+          tile.reg.xetla_select<max_load_vec_len * scale_factor, 1>(offset_x);
       uint32_t address_offset = offset_x * sizeof(dtype);
-      reg_sub.xetla_format<load_dtype>() =
-          xetla_load_local<load_dtype, 64, data_size::default_size>(
-              payload.address + address_offset);
+      reg_sub.xetla_format<load_dtype>() = xetla_load_local<
+          load_dtype,
+          max_load_vec_len,
+          data_size::default_size>(payload.address + address_offset);
     }
   }
-  detail::
-      process_1d_tail<load_len % 64, 32, detail::process_flag::load, L1, L2>(
-          tile, payload, load_len / 64 * 64 * scale_factor);
+  detail::process_1d_tail<
+      load_len % max_load_vec_len,
+      (max_load_vec_len >> 1),
+      detail::process_flag::load,
+      L1,
+      L2>(tile, payload, load_iter_steps * max_load_vec_len * scale_factor);
 }
 
 } // namespace gpu::xetla::subgroup
0 commit comments