Skip to content
This repository was archived by the owner on Aug 30, 2024. It is now read-only.

Commit bd516b2

Browse files
committed
update prefetch
1 parent b3cb404 commit bd516b2

File tree

5 files changed

+414
-389
lines changed

5 files changed

+414
-389
lines changed

include/common/core/arch_config.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -119,7 +119,7 @@ struct load_store_attr_t<msg_type::block_1d, arch_tag> {
119119
static constexpr uint32_t max_aligned_load_vec_len = 256;
120120
static constexpr uint32_t max_store_vec_len = 256;
121121
static constexpr uint32_t max_aligned_store_vec_len = 256;
122-
static constexpr uint32_t max_prefetch_vec_len = 32;
122+
static constexpr uint32_t max_prefetch_vec_len = 256;
123123
static constexpr uint32_t max_channel_num = 16;
124124
};
125125

include/subgroup/tile/impl/load_xe.hpp

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -214,8 +214,10 @@ tile_load(tile_t& tile, payload_t& payload) {
214214
// arch_tag>(tdesc);
215215
reg_tmp.xetla_format<native_type_t<load_dtype>>() = xetla_load_global<
216216
native_type_t<load_dtype>,
217-
block_size_x / scale_factor,
218-
block_size_y,
217+
(mem_transpose ? ld_blk_size_y : block_size_x) / scale_factor,
218+
(mem_transpose ? block_size_x : ld_blk_size_y),
219+
// block_size_x / scale_factor,
220+
// ld_blk_size_y,
219221
arr_len,
220222
trans,
221223
mem_transform,

include/subgroup/tile/impl/payload_xe.hpp

Lines changed: 75 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -1204,10 +1204,9 @@ struct mem_payload_t<
12041204
static constexpr uint32_t tile_size_y = tile_desc::tile_size_y;
12051205
static constexpr uint32_t block_size_x = tile_desc::block_size_x;
12061206
static constexpr uint32_t block_size_y = tile_desc::block_size_y;
1207-
static constexpr uint32_t tile_bytes =
1208-
tile_size_x * tile_size_y * sizeof(dtype);
1207+
static constexpr uint32_t tile_bytes = tile_desc::tile_elems * sizeof(dtype);
12091208
static constexpr uint32_t block_bytes =
1210-
block_size_x * block_size_y * sizeof(dtype);
1209+
tile_desc::block_elems * sizeof(dtype);
12111210
using this_payload_t =
12121211
mem_payload_t<mem_desc_t, tile_desc, msg_type::block_2d, arch_tag_>;
12131212

@@ -1284,7 +1283,7 @@ struct mem_payload_t<
12841283
base_offset = mem_transpose
12851284
? base_x * pitch_in_bytes + base_y * sizeof(dtype)
12861285
: base_y * pitch_in_bytes + base_x * sizeof(dtype);
1287-
base_ptr = (mem_dtype*)mem_tdesc.base.base;
1286+
base_ptr = reinterpret_cast<mem_dtype*>(mem_tdesc.base.base);
12881287

12891288
xetla_vector<uint32_t, num_channel> channel_index =
12901289
xetla_vector_gen<uint32_t, num_channel>(0, 1);
@@ -1789,10 +1788,8 @@ struct prefetch_payload_t<
17891788
static constexpr uint32_t tile_size_y = tile_desc::tile_size_y;
17901789
static constexpr uint32_t block_size_x = tile_desc::block_size_x;
17911790
static constexpr uint32_t block_size_y = tile_desc::block_size_y;
1792-
static constexpr uint32_t tile_bytes =
1793-
tile_size_x * tile_size_y * sizeof(dtype);
1794-
static constexpr uint32_t block_bytes =
1795-
block_size_x * block_size_y * sizeof(dtype);
1791+
static constexpr uint32_t tile_bytes = tile_desc::block_elems * sizeof(dtype);
1792+
static constexpr uint32_t block_bytes = tile_desc::tile_elems * sizeof(dtype);
17961793

17971794
private:
17981795
using this_payload_t =
@@ -1835,41 +1832,57 @@ struct prefetch_payload_t<
18351832
static constexpr uint32_t num_channel = select_channel(
18361833
std::min(mem_transpose ? block_size_x : block_size_y, max_channel));
18371834

1838-
static constexpr uint32_t mem_tile_size_w =
1839-
mem_transpose ? tile_size_y : tile_size_x;
1840-
static constexpr uint32_t mem_tile_size_h =
1841-
mem_transpose ? tile_size_x : tile_size_y;
1842-
using load_store_attr =
1843-
typename arch_attr_t<arch_tag>::template load_store_attr<message_type>;
1844-
static constexpr uint32_t special_prefetch_width =
1845-
load_store_attr::special_prefetch_width_in_bytes / sizeof(dtype);
1846-
static constexpr uint32_t normal_prefetch_width =
1847-
load_store_attr::max_load_width_in_bytes / sizeof(dtype);
1848-
static constexpr bool is_special_prefetch =
1849-
(mem_tile_size_w % special_prefetch_width) == 0;
1835+
static constexpr uint32_t max_channel =
1836+
max_prefetch_vec_len / (vector_size * sizeof(prefetch_dtype));
18501837

1851-
static constexpr uint32_t block_size_w = is_special_prefetch
1852-
? special_prefetch_width
1853-
: (normal_prefetch_width > mem_tile_size_w ? mem_tile_size_w
1854-
: normal_prefetch_width);
1855-
static constexpr uint32_t block_size_h =
1856-
load_store_attr::max_load_height_in_elem;
1857-
// could have over-prefetch, but that's should be fine
1858-
static constexpr uint32_t max_num_block_w =
1859-
(mem_tile_size_w + block_size_w - 1) / block_size_w;
1860-
static constexpr uint32_t num_coop_sg = num_coop_sg_;
1861-
static constexpr uint32_t num_coop_sg_w =
1862-
detail::gcd<num_coop_sg, max_num_block_w>::value;
1863-
static constexpr uint32_t num_coop_sg_h = num_coop_sg / num_coop_sg_w;
1838+
static constexpr uint32_t select_channel(const uint32_t channel) {
1839+
return (channel >= load_store_attr::max_channel_num)
1840+
? load_store_attr::max_channel_num
1841+
: channel >= 16 ? 16
1842+
: channel >= 8 ? 8
1843+
: 1;
1844+
}
18641845

1865-
static constexpr uint32_t num_block_w = max_num_block_w / num_coop_sg_w;
1866-
static constexpr uint32_t tile_size_w = block_size_w * num_block_w;
1867-
static constexpr uint32_t tile_size_h =
1868-
(mem_tile_size_h + num_coop_sg_h - 1) / num_coop_sg_h;
1869-
static constexpr uint32_t num_block_h =
1870-
(tile_size_h + block_size_h - 1) / block_size_h;
1846+
static constexpr uint32_t num_channel = select_channel(
1847+
std::min(mem_transpose ? block_size_x : block_size_y, max_channel));
1848+
1849+
// static constexpr uint32_t mem_tile_size_w =
1850+
// mem_transpose ? tile_size_y : tile_size_x;
1851+
// static constexpr uint32_t mem_tile_size_h =
1852+
// mem_transpose ? tile_size_x : tile_size_y;
1853+
1854+
// static constexpr uint32_t special_prefetch_width =
1855+
// load_store_attr::special_prefetch_width_in_bytes / sizeof(dtype);
1856+
// static constexpr uint32_t normal_prefetch_width =
1857+
// load_store_attr::max_load_width_in_bytes / sizeof(dtype);
1858+
// static constexpr bool is_special_prefetch =
1859+
// (mem_tile_size_w % special_prefetch_width) == 0;
1860+
1861+
// static constexpr uint32_t block_size_w = is_special_prefetch
1862+
// ? special_prefetch_width
1863+
// : (normal_prefetch_width > mem_tile_size_w ? mem_tile_size_w
1864+
// : normal_prefetch_width);
1865+
// static constexpr uint32_t block_size_h =
1866+
// load_store_attr::max_load_height_in_elem;
1867+
// // could have over-prefetch, but that's should be fine
1868+
// static constexpr uint32_t max_num_block_w =
1869+
// (mem_tile_size_w + block_size_w - 1) / block_size_w;
1870+
// static constexpr uint32_t num_coop_sg = num_coop_sg_;
1871+
// static constexpr uint32_t num_coop_sg_w =
1872+
// detail::gcd<num_coop_sg, max_num_block_w>::value;
1873+
// static constexpr uint32_t num_coop_sg_h = num_coop_sg / num_coop_sg_w;
1874+
1875+
// static constexpr uint32_t num_block_w = max_num_block_w / num_coop_sg_w;
1876+
// static constexpr uint32_t tile_size_w = block_size_w * num_block_w;
1877+
// static constexpr uint32_t tile_size_h =
1878+
// (mem_tile_size_h + num_coop_sg_h - 1) / num_coop_sg_h;
1879+
// static constexpr uint32_t num_block_h =
1880+
// (tile_size_h + block_size_h - 1) / block_size_h;
18711881

18721882
xetla_vector<uint32_t, num_channel> channel_offset;
1883+
xetla_vector<uint32_t, num_channel> step_x;
1884+
xetla_vector<uint32_t, num_channel> step_y;
1885+
18731886
uint64_t base_offset;
18741887
uint32_t base_x;
18751888
uint32_t base_y;
@@ -1906,13 +1919,15 @@ struct prefetch_payload_t<
19061919
return *this;
19071920
}
19081921

1909-
inline prefetch_payload_t(mem_desc_t& mem_desc, uint32_t coop_id = 0) {
1910-
uint32_t coop_id_x = coop_id % num_coop_sg_w;
1911-
uint32_t coop_id_y = coop_id / num_coop_sg_w;
1922+
inline prefetch_payload_t(
1923+
mem_desc_t& mem_desc,
1924+
[[maybe_unused]] uint32_t coop_id = 0) {
1925+
// uint32_t coop_id_x = coop_id % num_coop_sg_w;
1926+
// uint32_t coop_id_y = coop_id / num_coop_sg_w;
19121927

19131928
pitch_in_bytes = mem_desc.shape.stride * sizeof(dtype);
1914-
base_x = mem_desc.coord.x + coop_id_x * tile_size_w;
1915-
base_y = mem_desc.coord.y + coop_id_y * tile_size_h;
1929+
base_x = mem_desc.coord.x;
1930+
base_y = mem_desc.coord.y;
19161931
width_in_elems = mem_desc.shape.x;
19171932
height_in_elems = mem_desc.shape.y;
19181933
base_offset = mem_transpose
@@ -1932,13 +1947,15 @@ struct prefetch_payload_t<
19321947
int surface_pitch,
19331948
int surface_offset_x,
19341949
int surface_offset_y,
1935-
uint32_t coop_id = 0) {
1936-
uint32_t coop_id_x = coop_id % num_coop_sg_w;
1937-
uint32_t coop_id_y = coop_id / num_coop_sg_w;
1950+
[[maybe_unused]] uint32_t coop_id = 0) {
1951+
// uint32_t coop_id_x = coop_id % num_coop_sg_w;
1952+
// uint32_t coop_id_y = coop_id / num_coop_sg_w;
1953+
// base_x = surface_offset_x + coop_id_x * tile_size_w;
1954+
// base_y = surface_offset_y + coop_id_y * tile_size_h;
19381955

19391956
pitch_in_bytes = surface_pitch * sizeof(dtype);
1940-
base_x = surface_offset_x + coop_id_x * tile_size_w;
1941-
base_y = surface_offset_y + coop_id_y * tile_size_h;
1957+
base_x = surface_offset_x;
1958+
base_y = surface_offset_y;
19421959
width_in_elems = surface_width;
19431960
height_in_elems = surface_height;
19441961
base_offset = mem_transpose
@@ -1951,13 +1968,17 @@ struct prefetch_payload_t<
19511968
channel_offset = channel_index * pitch_in_bytes;
19521969
}
19531970

1954-
inline void init(mem_desc_t& mem_desc, uint32_t coop_id = 0) {
1955-
uint32_t coop_id_x = coop_id % num_coop_sg_w;
1956-
uint32_t coop_id_y = coop_id / num_coop_sg_w;
1971+
inline void init(
1972+
mem_desc_t& mem_desc,
1973+
[[maybe_unused]] uint32_t coop_id = 0) {
1974+
// uint32_t coop_id_x = coop_id % num_coop_sg_w;
1975+
// uint32_t coop_id_y = coop_id / num_coop_sg_w;
1976+
// base_x = mem_desc.coord.x + coop_id_x * tile_size_w;
1977+
// base_y = mem_desc.coord.y + coop_id_y * tile_size_h;
19571978

19581979
pitch_in_bytes = mem_desc.shape.stride * sizeof(dtype);
1959-
base_x = mem_desc.coord.x + coop_id_x * tile_size_w;
1960-
base_y = mem_desc.coord.y + coop_id_y * tile_size_h;
1980+
base_x = mem_desc.coord.x;
1981+
base_y = mem_desc.coord.y;
19611982
width_in_elems = mem_desc.shape.x;
19621983
height_in_elems = mem_desc.shape.y;
19631984
base_offset = mem_transpose

include/subgroup/tile/impl/prefetch_xe.hpp

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -104,8 +104,7 @@ tile_prefetch(payload_t& payload) {
104104
using prefetch_dtype = typename payload_t::prefetch_dtype;
105105
constexpr uint32_t num_channel = payload_t::num_channel;
106106
#pragma unroll
107-
for (uint32_t i = 0; i < tile_desc::tile_size_y / tile_desc::block_size_y;
108-
i++) {
107+
for (uint32_t i = 0; i < tile_desc::num_block_y; i++) {
109108
uint32_t offset_y = i * tile_desc::block_size_y;
110109
#pragma unroll
111110
for (uint32_t j = 0; j < tile_desc::num_block_x; j++) {
@@ -126,7 +125,6 @@ tile_prefetch(payload_t& payload) {
126125
L2>(
127126
payload.base_ptr,
128127
payload.channel_offset + payload.base_offset + address_offset);
129-
// }
130128
}
131129
}
132130
}

0 commit comments

Comments
 (0)