@@ -1204,10 +1204,9 @@ struct mem_payload_t<
12041204 static constexpr uint32_t tile_size_y = tile_desc::tile_size_y;
12051205 static constexpr uint32_t block_size_x = tile_desc::block_size_x;
12061206 static constexpr uint32_t block_size_y = tile_desc::block_size_y;
1207- static constexpr uint32_t tile_bytes =
1208- tile_size_x * tile_size_y * sizeof (dtype);
1207+ static constexpr uint32_t tile_bytes = tile_desc::tile_elems * sizeof (dtype);
12091208 static constexpr uint32_t block_bytes =
1210- block_size_x * block_size_y * sizeof (dtype);
1209+ tile_desc::block_elems * sizeof (dtype);
12111210 using this_payload_t =
12121211 mem_payload_t <mem_desc_t , tile_desc, msg_type::block_2d, arch_tag_>;
12131212
@@ -1284,7 +1283,7 @@ struct mem_payload_t<
12841283 base_offset = mem_transpose
12851284 ? base_x * pitch_in_bytes + base_y * sizeof (dtype)
12861285 : base_y * pitch_in_bytes + base_x * sizeof (dtype);
1287- base_ptr = ( mem_dtype*) mem_tdesc.base .base ;
1286+ base_ptr = reinterpret_cast < mem_dtype*>( mem_tdesc.base .base ) ;
12881287
12891288 xetla_vector<uint32_t , num_channel> channel_index =
12901289 xetla_vector_gen<uint32_t , num_channel>(0 , 1 );
@@ -1789,10 +1788,8 @@ struct prefetch_payload_t<
17891788 static constexpr uint32_t tile_size_y = tile_desc::tile_size_y;
17901789 static constexpr uint32_t block_size_x = tile_desc::block_size_x;
17911790 static constexpr uint32_t block_size_y = tile_desc::block_size_y;
1792- static constexpr uint32_t tile_bytes =
1793- tile_size_x * tile_size_y * sizeof (dtype);
1794- static constexpr uint32_t block_bytes =
1795- block_size_x * block_size_y * sizeof (dtype);
1791+ static constexpr uint32_t tile_bytes = tile_desc::block_elems * sizeof (dtype);
1792+ static constexpr uint32_t block_bytes = tile_desc::tile_elems * sizeof (dtype);
17961793
17971794 private:
17981795 using this_payload_t =
@@ -1835,41 +1832,57 @@ struct prefetch_payload_t<
18351832 static constexpr uint32_t num_channel = select_channel(
18361833 std::min (mem_transpose ? block_size_x : block_size_y, max_channel));
18371834
1838- static constexpr uint32_t mem_tile_size_w =
1839- mem_transpose ? tile_size_y : tile_size_x;
1840- static constexpr uint32_t mem_tile_size_h =
1841- mem_transpose ? tile_size_x : tile_size_y;
1842- using load_store_attr =
1843- typename arch_attr_t <arch_tag>::template load_store_attr<message_type>;
1844- static constexpr uint32_t special_prefetch_width =
1845- load_store_attr::special_prefetch_width_in_bytes / sizeof (dtype);
1846- static constexpr uint32_t normal_prefetch_width =
1847- load_store_attr::max_load_width_in_bytes / sizeof (dtype);
1848- static constexpr bool is_special_prefetch =
1849- (mem_tile_size_w % special_prefetch_width) == 0 ;
1835+ static constexpr uint32_t max_channel =
1836+ max_prefetch_vec_len / (vector_size * sizeof (prefetch_dtype));
18501837
1851- static constexpr uint32_t block_size_w = is_special_prefetch
1852- ? special_prefetch_width
1853- : (normal_prefetch_width > mem_tile_size_w ? mem_tile_size_w
1854- : normal_prefetch_width);
1855- static constexpr uint32_t block_size_h =
1856- load_store_attr::max_load_height_in_elem;
1857- // could have over-prefetch, but that's should be fine
1858- static constexpr uint32_t max_num_block_w =
1859- (mem_tile_size_w + block_size_w - 1 ) / block_size_w;
1860- static constexpr uint32_t num_coop_sg = num_coop_sg_;
1861- static constexpr uint32_t num_coop_sg_w =
1862- detail::gcd<num_coop_sg, max_num_block_w>::value;
1863- static constexpr uint32_t num_coop_sg_h = num_coop_sg / num_coop_sg_w;
1838+ static constexpr uint32_t select_channel (const uint32_t channel) {
1839+ return (channel >= load_store_attr::max_channel_num)
1840+ ? load_store_attr::max_channel_num
1841+ : channel >= 16 ? 16
1842+ : channel >= 8 ? 8
1843+ : 1 ;
1844+ }
18641845
1865- static constexpr uint32_t num_block_w = max_num_block_w / num_coop_sg_w;
1866- static constexpr uint32_t tile_size_w = block_size_w * num_block_w;
1867- static constexpr uint32_t tile_size_h =
1868- (mem_tile_size_h + num_coop_sg_h - 1 ) / num_coop_sg_h;
1869- static constexpr uint32_t num_block_h =
1870- (tile_size_h + block_size_h - 1 ) / block_size_h;
1846+ static constexpr uint32_t num_channel = select_channel(
1847+ std::min (mem_transpose ? block_size_x : block_size_y, max_channel));
1848+
1849+ // static constexpr uint32_t mem_tile_size_w =
1850+ // mem_transpose ? tile_size_y : tile_size_x;
1851+ // static constexpr uint32_t mem_tile_size_h =
1852+ // mem_transpose ? tile_size_x : tile_size_y;
1853+
1854+ // static constexpr uint32_t special_prefetch_width =
1855+ // load_store_attr::special_prefetch_width_in_bytes / sizeof(dtype);
1856+ // static constexpr uint32_t normal_prefetch_width =
1857+ // load_store_attr::max_load_width_in_bytes / sizeof(dtype);
1858+ // static constexpr bool is_special_prefetch =
1859+ // (mem_tile_size_w % special_prefetch_width) == 0;
1860+
1861+ // static constexpr uint32_t block_size_w = is_special_prefetch
1862+ // ? special_prefetch_width
1863+ // : (normal_prefetch_width > mem_tile_size_w ? mem_tile_size_w
1864+ // : normal_prefetch_width);
1865+ // static constexpr uint32_t block_size_h =
1866+ // load_store_attr::max_load_height_in_elem;
1867+ // // could have over-prefetch, but that's should be fine
1868+ // static constexpr uint32_t max_num_block_w =
1869+ // (mem_tile_size_w + block_size_w - 1) / block_size_w;
1870+ // static constexpr uint32_t num_coop_sg = num_coop_sg_;
1871+ // static constexpr uint32_t num_coop_sg_w =
1872+ // detail::gcd<num_coop_sg, max_num_block_w>::value;
1873+ // static constexpr uint32_t num_coop_sg_h = num_coop_sg / num_coop_sg_w;
1874+
1875+ // static constexpr uint32_t num_block_w = max_num_block_w / num_coop_sg_w;
1876+ // static constexpr uint32_t tile_size_w = block_size_w * num_block_w;
1877+ // static constexpr uint32_t tile_size_h =
1878+ // (mem_tile_size_h + num_coop_sg_h - 1) / num_coop_sg_h;
1879+ // static constexpr uint32_t num_block_h =
1880+ // (tile_size_h + block_size_h - 1) / block_size_h;
18711881
18721882 xetla_vector<uint32_t , num_channel> channel_offset;
1883+ xetla_vector<uint32_t , num_channel> step_x;
1884+ xetla_vector<uint32_t , num_channel> step_y;
1885+
18731886 uint64_t base_offset;
18741887 uint32_t base_x;
18751888 uint32_t base_y;
@@ -1906,13 +1919,15 @@ struct prefetch_payload_t<
19061919 return *this ;
19071920 }
19081921
1909- inline prefetch_payload_t (mem_desc_t & mem_desc, uint32_t coop_id = 0 ) {
1910- uint32_t coop_id_x = coop_id % num_coop_sg_w;
1911- uint32_t coop_id_y = coop_id / num_coop_sg_w;
1922+ inline prefetch_payload_t (
1923+ mem_desc_t & mem_desc,
1924+ [[maybe_unused]] uint32_t coop_id = 0 ) {
1925+ // uint32_t coop_id_x = coop_id % num_coop_sg_w;
1926+ // uint32_t coop_id_y = coop_id / num_coop_sg_w;
19121927
19131928 pitch_in_bytes = mem_desc.shape .stride * sizeof (dtype);
1914- base_x = mem_desc.coord .x + coop_id_x * tile_size_w ;
1915- base_y = mem_desc.coord .y + coop_id_y * tile_size_h ;
1929+ base_x = mem_desc.coord .x ;
1930+ base_y = mem_desc.coord .y ;
19161931 width_in_elems = mem_desc.shape .x ;
19171932 height_in_elems = mem_desc.shape .y ;
19181933 base_offset = mem_transpose
@@ -1932,13 +1947,15 @@ struct prefetch_payload_t<
19321947 int surface_pitch,
19331948 int surface_offset_x,
19341949 int surface_offset_y,
1935- uint32_t coop_id = 0 ) {
1936- uint32_t coop_id_x = coop_id % num_coop_sg_w;
1937- uint32_t coop_id_y = coop_id / num_coop_sg_w;
1950+ [[maybe_unused]] uint32_t coop_id = 0 ) {
1951+ // uint32_t coop_id_x = coop_id % num_coop_sg_w;
1952+ // uint32_t coop_id_y = coop_id / num_coop_sg_w;
1953+ // base_x = surface_offset_x + coop_id_x * tile_size_w;
1954+ // base_y = surface_offset_y + coop_id_y * tile_size_h;
19381955
19391956 pitch_in_bytes = surface_pitch * sizeof (dtype);
1940- base_x = surface_offset_x + coop_id_x * tile_size_w ;
1941- base_y = surface_offset_y + coop_id_y * tile_size_h ;
1957+ base_x = surface_offset_x;
1958+ base_y = surface_offset_y;
19421959 width_in_elems = surface_width;
19431960 height_in_elems = surface_height;
19441961 base_offset = mem_transpose
@@ -1951,13 +1968,17 @@ struct prefetch_payload_t<
19511968 channel_offset = channel_index * pitch_in_bytes;
19521969 }
19531970
1954- inline void init (mem_desc_t & mem_desc, uint32_t coop_id = 0 ) {
1955- uint32_t coop_id_x = coop_id % num_coop_sg_w;
1956- uint32_t coop_id_y = coop_id / num_coop_sg_w;
1971+ inline void init (
1972+ mem_desc_t & mem_desc,
1973+ [[maybe_unused]] uint32_t coop_id = 0 ) {
1974+ // uint32_t coop_id_x = coop_id % num_coop_sg_w;
1975+ // uint32_t coop_id_y = coop_id / num_coop_sg_w;
1976+ // base_x = mem_desc.coord.x + coop_id_x * tile_size_w;
1977+ // base_y = mem_desc.coord.y + coop_id_y * tile_size_h;
19571978
19581979 pitch_in_bytes = mem_desc.shape .stride * sizeof (dtype);
1959- base_x = mem_desc.coord .x + coop_id_x * tile_size_w ;
1960- base_y = mem_desc.coord .y + coop_id_y * tile_size_h ;
1980+ base_x = mem_desc.coord .x ;
1981+ base_y = mem_desc.coord .y ;
19611982 width_in_elems = mem_desc.shape .x ;
19621983 height_in_elems = mem_desc.shape .y ;
19631984 base_offset = mem_transpose
0 commit comments