@@ -1170,10 +1170,9 @@ struct mem_payload_t<
11701170 static constexpr uint32_t tile_size_y = tile_desc::tile_size_y;
11711171 static constexpr uint32_t block_size_x = tile_desc::block_size_x;
11721172 static constexpr uint32_t block_size_y = tile_desc::block_size_y;
1173- static constexpr uint32_t tile_bytes =
1174- tile_size_x * tile_size_y * sizeof (dtype);
1173+ static constexpr uint32_t tile_bytes = tile_desc::tile_elems * sizeof (dtype);
11751174 static constexpr uint32_t block_bytes =
1176- block_size_x * block_size_y * sizeof (dtype);
1175+ tile_desc::block_elems * sizeof (dtype);
11771176 using this_payload_t =
11781177 mem_payload_t <mem_desc_t , tile_desc, msg_type::block_2d, arch_tag_>;
11791178
@@ -1250,7 +1249,7 @@ struct mem_payload_t<
12501249 base_offset = mem_transpose
12511250 ? base_x * pitch_in_bytes + base_y * sizeof (dtype)
12521251 : base_y * pitch_in_bytes + base_x * sizeof (dtype);
1253- base_ptr = ( mem_dtype*) mem_tdesc.base .base ;
1252+ base_ptr = reinterpret_cast < mem_dtype*>( mem_tdesc.base .base ) ;
12541253
12551254 xetla_vector<uint32_t , num_channel> channel_index =
12561255 xetla_vector_gen<uint32_t , num_channel>(0 , 1 );
@@ -1709,11 +1708,12 @@ struct prefetch_payload_t<
17091708 reg_layout_>,
17101709 num_coop_sg_,
17111710 arch_tag_,
1712- std::enable_if_t <(!arch_has_2d_load_store<arch_tag_>)&&(
1713- ((block_size_y_ != 1 || tile_size_y_ != 1 ) &&
1714- mem_layout_ == mem_layout::row_major) ||
1715- ((block_size_x_ != 1 || tile_size_x_ != 1 ) &&
1716- mem_layout_ == mem_layout::col_major))>> {
1711+ std::enable_if_t <
1712+ (!arch_has_2d_load_store<arch_tag_>) &&
1713+ (((block_size_y_ != 1 || tile_size_y_ != 1 ) &&
1714+ mem_layout_ == mem_layout::row_major) ||
1715+ ((block_size_x_ != 1 || tile_size_x_ != 1 ) &&
1716+ mem_layout_ == mem_layout::col_major))>> {
17171717 using dtype = native_type_t <dtype_>;
17181718 using mem_desc_t =
17191719 mem_desc_t <dtype_, mem_layout_, mem_space::global, alignment_>;
@@ -1734,10 +1734,8 @@ struct prefetch_payload_t<
17341734 static constexpr uint32_t tile_size_y = tile_desc::tile_size_y;
17351735 static constexpr uint32_t block_size_x = tile_desc::block_size_x;
17361736 static constexpr uint32_t block_size_y = tile_desc::block_size_y;
1737- static constexpr uint32_t tile_bytes =
1738- tile_size_x * tile_size_y * sizeof (dtype);
1739- static constexpr uint32_t block_bytes =
1740- block_size_x * block_size_y * sizeof (dtype);
1737+ static constexpr uint32_t tile_bytes = tile_desc::tile_elems * sizeof (dtype); // bytes in the whole tile (tile_elems elements)
1738+ static constexpr uint32_t block_bytes = tile_desc::block_elems * sizeof (dtype); // bytes in one block (block_elems elements)
17411739
17421740 private:
17431741 using this_payload_t =
@@ -1751,67 +1749,75 @@ struct prefetch_payload_t<
17511749 static constexpr bool trans = (mem_transpose ^ reg_transpose) &&
17521750 !(std::is_same_v<dtype_, int4x2> || std::is_same_v<dtype_, int4x8>);
17531751
1754- using prefetch_dtype = typename std::conditional <
1752+ using prefetch_dtype = typename std::conditional_t <
17551753 (alignment_in_bytes % (sizeof (uint64_t )) == 0 ),
17561754 uint64_t ,
1757- typename std::conditional <
1755+ typename std::conditional_t <
17581756 (alignment_in_bytes % (sizeof (uint32_t )) == 0 ),
17591757 uint32_t ,
1760- dtype>::type>::type ;
1758+ dtype>> ;
17611759 static constexpr uint32_t pack_factor =
17621760 sizeof (prefetch_dtype) / sizeof (dtype);
17631761
1764- static constexpr uint32_t min_store_bytes = 16 * sizeof (dtype);
1765- static constexpr uint32_t max_store_bytes = 32 * sizeof (dtype);
1766- static constexpr uint32_t simd_channel =
1767- ((tile_bytes % max_store_bytes) == 0 &&
1768- (block_bytes % max_store_bytes) == 0 )
1769- ? 32
1770- : 16 ;
1771- static constexpr uint32_t num_channel = mem_transpose
1772- ? (simd_channel >= block_size_x) ? block_size_x : simd_channel
1773- : (simd_channel >= block_size_y) ? block_size_y
1774- : simd_channel;
1762+ static constexpr uint32_t vector_size =
1763+ ((mem_transpose ? block_size_y : block_size_x) + pack_factor - 1 ) /
1764+ pack_factor;
17751765
1776- static constexpr uint32_t vector_size = mem_transpose
1777- ? (block_size_y + pack_factor - 1 ) / pack_factor
1778- : (block_size_x + pack_factor - 1 ) / pack_factor ;
1766+ using load_store_attr = load_store_attr_t <msg_type::block_1d, arch_tag>;
1767+ static constexpr uint32_t max_prefetch_vec_len =
1768+ load_store_attr::max_prefetch_vec_len ;
17791769
1780- static constexpr uint32_t mem_tile_size_w =
1781- mem_transpose ? tile_size_y : tile_size_x;
1782- static constexpr uint32_t mem_tile_size_h =
1783- mem_transpose ? tile_size_x : tile_size_y;
1784- using load_store_attr =
1785- typename arch_attr_t <arch_tag>::template load_store_attr<message_type>;
1786- static constexpr uint32_t special_prefetch_width =
1787- load_store_attr::special_prefetch_width_in_bytes / sizeof (dtype);
1788- static constexpr uint32_t normal_prefetch_width =
1789- load_store_attr::max_load_width_in_bytes / sizeof (dtype);
1790- static constexpr bool is_special_prefetch =
1791- (mem_tile_size_w % special_prefetch_width) == 0 ;
1770+ static constexpr uint32_t max_channel =
1771+ max_prefetch_vec_len / (vector_size * sizeof (prefetch_dtype));
17921772
1793- static constexpr uint32_t block_size_w = is_special_prefetch
1794- ? special_prefetch_width
1795- : (normal_prefetch_width > mem_tile_size_w ? mem_tile_size_w
1796- : normal_prefetch_width);
1797- static constexpr uint32_t block_size_h =
1798- load_store_attr::max_load_height_in_elem;
1799- // could have over-prefetch, but that's should be fine
1800- static constexpr uint32_t max_num_block_w =
1801- (mem_tile_size_w + block_size_w - 1 ) / block_size_w;
1802- static constexpr uint32_t num_coop_sg = num_coop_sg_;
1803- static constexpr uint32_t num_coop_sg_w =
1804- detail::gcd<num_coop_sg, max_num_block_w>::value;
1805- static constexpr uint32_t num_coop_sg_h = num_coop_sg / num_coop_sg_w;
1773+ static constexpr uint32_t select_channel (const uint32_t channel) { // maps a requested channel count onto max_channel_num, 16, 8 or 1
1774+ return (channel >= load_store_attr::max_channel_num) // at or above the attribute's maximum:
1775+ ? load_store_attr::max_channel_num // clamp to that maximum
1776+ : channel >= 16 ? 16 // otherwise use 16 when at least 16 were requested
1777+ : channel >= 8 ? 8 // otherwise use 8 when at least 8 were requested
1778+ : 1 ; // anything below 8 falls back to a single channel
1779+ }
18061780
1807- static constexpr uint32_t num_block_w = max_num_block_w / num_coop_sg_w;
1808- static constexpr uint32_t tile_size_w = block_size_w * num_block_w;
1809- static constexpr uint32_t tile_size_h =
1810- (mem_tile_size_h + num_coop_sg_h - 1 ) / num_coop_sg_h;
1811- static constexpr uint32_t num_block_h =
1812- (tile_size_h + block_size_h - 1 ) / block_size_h;
1781+ static constexpr uint32_t num_channel = select_channel(
1782+ std::min (mem_transpose ? block_size_x : block_size_y, max_channel));
1783+
1784+ // static constexpr uint32_t mem_tile_size_w =
1785+ // mem_transpose ? tile_size_y : tile_size_x;
1786+ // static constexpr uint32_t mem_tile_size_h =
1787+ // mem_transpose ? tile_size_x : tile_size_y;
1788+
1789+ // static constexpr uint32_t special_prefetch_width =
1790+ // load_store_attr::special_prefetch_width_in_bytes / sizeof(dtype);
1791+ // static constexpr uint32_t normal_prefetch_width =
1792+ // load_store_attr::max_load_width_in_bytes / sizeof(dtype);
1793+ // static constexpr bool is_special_prefetch =
1794+ // (mem_tile_size_w % special_prefetch_width) == 0;
1795+
1796+ // static constexpr uint32_t block_size_w = is_special_prefetch
1797+ // ? special_prefetch_width
1798+ // : (normal_prefetch_width > mem_tile_size_w ? mem_tile_size_w
1799+ // : normal_prefetch_width);
1800+ // static constexpr uint32_t block_size_h =
1801+ // load_store_attr::max_load_height_in_elem;
1802+ // // could over-prefetch, but that should be fine
1803+ // static constexpr uint32_t max_num_block_w =
1804+ // (mem_tile_size_w + block_size_w - 1) / block_size_w;
1805+ // static constexpr uint32_t num_coop_sg = num_coop_sg_;
1806+ // static constexpr uint32_t num_coop_sg_w =
1807+ // detail::gcd<num_coop_sg, max_num_block_w>::value;
1808+ // static constexpr uint32_t num_coop_sg_h = num_coop_sg / num_coop_sg_w;
1809+
1810+ // static constexpr uint32_t num_block_w = max_num_block_w / num_coop_sg_w;
1811+ // static constexpr uint32_t tile_size_w = block_size_w * num_block_w;
1812+ // static constexpr uint32_t tile_size_h =
1813+ // (mem_tile_size_h + num_coop_sg_h - 1) / num_coop_sg_h;
1814+ // static constexpr uint32_t num_block_h =
1815+ // (tile_size_h + block_size_h - 1) / block_size_h;
18131816
18141817 xetla_vector<uint32_t , num_channel> channel_offset;
1818+ xetla_vector<uint32_t , num_channel> step_x;
1819+ xetla_vector<uint32_t , num_channel> step_y;
1820+
18151821 uint64_t base_offset;
18161822 uint32_t base_x;
18171823 uint32_t base_y;
@@ -1848,13 +1854,15 @@ struct prefetch_payload_t<
18481854 return *this ;
18491855 }
18501856
1851- inline prefetch_payload_t (mem_desc_t & mem_desc, uint32_t coop_id = 0 ) {
1852- uint32_t coop_id_x = coop_id % num_coop_sg_w;
1853- uint32_t coop_id_y = coop_id / num_coop_sg_w;
1857+ inline prefetch_payload_t (
1858+ mem_desc_t & mem_desc,
1859+ [[maybe_unused]] uint32_t coop_id = 0 ) {
1860+ // uint32_t coop_id_x = coop_id % num_coop_sg_w;
1861+ // uint32_t coop_id_y = coop_id / num_coop_sg_w;
18541862
18551863 pitch_in_bytes = mem_desc.shape .stride * sizeof (dtype);
1856- base_x = mem_desc.coord .x + coop_id_x * tile_size_w ;
1857- base_y = mem_desc.coord .y + coop_id_y * tile_size_h ;
1864+ base_x = mem_desc.coord .x ;
1865+ base_y = mem_desc.coord .y ;
18581866 width_in_elems = mem_desc.shape .x ;
18591867 height_in_elems = mem_desc.shape .y ;
18601868 base_offset = mem_transpose
@@ -1874,13 +1882,15 @@ struct prefetch_payload_t<
18741882 int surface_pitch,
18751883 int surface_offset_x,
18761884 int surface_offset_y,
1877- uint32_t coop_id = 0 ) {
1878- uint32_t coop_id_x = coop_id % num_coop_sg_w;
1879- uint32_t coop_id_y = coop_id / num_coop_sg_w;
1885+ [[maybe_unused]] uint32_t coop_id = 0 ) {
1886+ // uint32_t coop_id_x = coop_id % num_coop_sg_w;
1887+ // uint32_t coop_id_y = coop_id / num_coop_sg_w;
1888+ // base_x = surface_offset_x + coop_id_x * tile_size_w;
1889+ // base_y = surface_offset_y + coop_id_y * tile_size_h;
18801890
18811891 pitch_in_bytes = surface_pitch * sizeof (dtype);
1882- base_x = surface_offset_x + coop_id_x * tile_size_w ;
1883- base_y = surface_offset_y + coop_id_y * tile_size_h ;
1892+ base_x = surface_offset_x;
1893+ base_y = surface_offset_y;
18841894 width_in_elems = surface_width;
18851895 height_in_elems = surface_height;
18861896 base_offset = mem_transpose
@@ -1893,13 +1903,17 @@ struct prefetch_payload_t<
18931903 channel_offset = channel_index * pitch_in_bytes;
18941904 }
18951905
1896- inline void init (mem_desc_t & mem_desc, uint32_t coop_id = 0 ) {
1897- uint32_t coop_id_x = coop_id % num_coop_sg_w;
1898- uint32_t coop_id_y = coop_id / num_coop_sg_w;
1906+ inline void init (
1907+ mem_desc_t & mem_desc,
1908+ [[maybe_unused]] uint32_t coop_id = 0 ) {
1909+ // uint32_t coop_id_x = coop_id % num_coop_sg_w;
1910+ // uint32_t coop_id_y = coop_id / num_coop_sg_w;
1911+ // base_x = mem_desc.coord.x + coop_id_x * tile_size_w;
1912+ // base_y = mem_desc.coord.y + coop_id_y * tile_size_h;
18991913
19001914 pitch_in_bytes = mem_desc.shape .stride * sizeof (dtype);
1901- base_x = mem_desc.coord .x + coop_id_x * tile_size_w ;
1902- base_y = mem_desc.coord .y + coop_id_y * tile_size_h ;
1915+ base_x = mem_desc.coord .x ;
1916+ base_y = mem_desc.coord .y ;
19031917 width_in_elems = mem_desc.shape .x ;
19041918 height_in_elems = mem_desc.shape .y ;
19051919 base_offset = mem_transpose
@@ -1955,9 +1969,10 @@ struct prefetch_payload_t<
19551969 reg_layout_>,
19561970 num_coop_sg_,
19571971 arch_tag_,
1958- std::enable_if_t <(arch_has_2d_load_store<arch_tag_>)&&(
1959- ((tile_size_y_ != 1 ) && mem_layout_ == mem_layout::row_major) ||
1960- ((tile_size_x_ != 1 ) && mem_layout_ == mem_layout::col_major))>> {
1972+ std::enable_if_t <
1973+ (arch_has_2d_load_store<arch_tag_>) &&
1974+ (((tile_size_y_ != 1 ) && mem_layout_ == mem_layout::row_major) ||
1975+ ((tile_size_x_ != 1 ) && mem_layout_ == mem_layout::col_major))>> {
19611976 using dtype = dtype_;
19621977 using mem_desc_t =
19631978 mem_desc_t <dtype_, mem_layout_, mem_space::global, alignment_>;
0 commit comments