Skip to content
This repository was archived by the owner on Aug 30, 2024. It is now read-only.

Commit 2dbe90c

Browse files
committed
update prefetch
1 parent f6a11b5 commit 2dbe90c

File tree

5 files changed

+433
-414
lines changed

5 files changed

+433
-414
lines changed

include/common/core/arch_config.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -119,7 +119,7 @@ struct load_store_attr_t<msg_type::block_1d, arch_tag> {
119119
static constexpr uint32_t max_aligned_load_vec_len = 256;
120120
static constexpr uint32_t max_store_vec_len = 256;
121121
static constexpr uint32_t max_aligned_store_vec_len = 256;
122-
static constexpr uint32_t max_prefetch_vec_len = 32;
122+
static constexpr uint32_t max_prefetch_vec_len = 256;
123123
static constexpr uint32_t max_channel_num = 16;
124124
};
125125

include/subgroup/tile/impl/load_xe.hpp

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -214,8 +214,10 @@ tile_load(tile_t& tile, payload_t& payload) {
214214
// arch_tag>(tdesc);
215215
reg_tmp.xetla_format<native_type_t<load_dtype>>() = xetla_load_global<
216216
native_type_t<load_dtype>,
217-
block_size_x / scale_factor,
218-
block_size_y,
217+
(mem_transpose ? ld_blk_size_y : block_size_x) / scale_factor,
218+
(mem_transpose ? block_size_x : ld_blk_size_y),
219+
// block_size_x / scale_factor,
220+
// ld_blk_size_y,
219221
arr_len,
220222
trans,
221223
mem_transform,

include/subgroup/tile/impl/payload_xe.hpp

Lines changed: 94 additions & 79 deletions
Original file line numberDiff line numberDiff line change
@@ -1170,10 +1170,9 @@ struct mem_payload_t<
11701170
static constexpr uint32_t tile_size_y = tile_desc::tile_size_y;
11711171
static constexpr uint32_t block_size_x = tile_desc::block_size_x;
11721172
static constexpr uint32_t block_size_y = tile_desc::block_size_y;
1173-
static constexpr uint32_t tile_bytes =
1174-
tile_size_x * tile_size_y * sizeof(dtype);
1173+
static constexpr uint32_t tile_bytes = tile_desc::tile_elems * sizeof(dtype);
11751174
static constexpr uint32_t block_bytes =
1176-
block_size_x * block_size_y * sizeof(dtype);
1175+
tile_desc::block_elems * sizeof(dtype);
11771176
using this_payload_t =
11781177
mem_payload_t<mem_desc_t, tile_desc, msg_type::block_2d, arch_tag_>;
11791178

@@ -1250,7 +1249,7 @@ struct mem_payload_t<
12501249
base_offset = mem_transpose
12511250
? base_x * pitch_in_bytes + base_y * sizeof(dtype)
12521251
: base_y * pitch_in_bytes + base_x * sizeof(dtype);
1253-
base_ptr = (mem_dtype*)mem_tdesc.base.base;
1252+
base_ptr = reinterpret_cast<mem_dtype*>(mem_tdesc.base.base);
12541253

12551254
xetla_vector<uint32_t, num_channel> channel_index =
12561255
xetla_vector_gen<uint32_t, num_channel>(0, 1);
@@ -1709,11 +1708,12 @@ struct prefetch_payload_t<
17091708
reg_layout_>,
17101709
num_coop_sg_,
17111710
arch_tag_,
1712-
std::enable_if_t<(!arch_has_2d_load_store<arch_tag_>)&&(
1713-
((block_size_y_ != 1 || tile_size_y_ != 1) &&
1714-
mem_layout_ == mem_layout::row_major) ||
1715-
((block_size_x_ != 1 || tile_size_x_ != 1) &&
1716-
mem_layout_ == mem_layout::col_major))>> {
1711+
std::enable_if_t<
1712+
(!arch_has_2d_load_store<arch_tag_>) &&
1713+
(((block_size_y_ != 1 || tile_size_y_ != 1) &&
1714+
mem_layout_ == mem_layout::row_major) ||
1715+
((block_size_x_ != 1 || tile_size_x_ != 1) &&
1716+
mem_layout_ == mem_layout::col_major))>> {
17171717
using dtype = native_type_t<dtype_>;
17181718
using mem_desc_t =
17191719
mem_desc_t<dtype_, mem_layout_, mem_space::global, alignment_>;
@@ -1734,10 +1734,8 @@ struct prefetch_payload_t<
17341734
static constexpr uint32_t tile_size_y = tile_desc::tile_size_y;
17351735
static constexpr uint32_t block_size_x = tile_desc::block_size_x;
17361736
static constexpr uint32_t block_size_y = tile_desc::block_size_y;
1737-
static constexpr uint32_t tile_bytes =
1738-
tile_size_x * tile_size_y * sizeof(dtype);
1739-
static constexpr uint32_t block_bytes =
1740-
block_size_x * block_size_y * sizeof(dtype);
1737+
static constexpr uint32_t tile_bytes = tile_desc::block_elems * sizeof(dtype);
1738+
static constexpr uint32_t block_bytes = tile_desc::tile_elems * sizeof(dtype);
17411739

17421740
private:
17431741
using this_payload_t =
@@ -1751,67 +1749,75 @@ struct prefetch_payload_t<
17511749
static constexpr bool trans = (mem_transpose ^ reg_transpose) &&
17521750
!(std::is_same_v<dtype_, int4x2> || std::is_same_v<dtype_, int4x8>);
17531751

1754-
using prefetch_dtype = typename std::conditional<
1752+
using prefetch_dtype = typename std::conditional_t<
17551753
(alignment_in_bytes % (sizeof(uint64_t)) == 0),
17561754
uint64_t,
1757-
typename std::conditional<
1755+
typename std::conditional_t<
17581756
(alignment_in_bytes % (sizeof(uint32_t)) == 0),
17591757
uint32_t,
1760-
dtype>::type>::type;
1758+
dtype>>;
17611759
static constexpr uint32_t pack_factor =
17621760
sizeof(prefetch_dtype) / sizeof(dtype);
17631761

1764-
static constexpr uint32_t min_store_bytes = 16 * sizeof(dtype);
1765-
static constexpr uint32_t max_store_bytes = 32 * sizeof(dtype);
1766-
static constexpr uint32_t simd_channel =
1767-
((tile_bytes % max_store_bytes) == 0 &&
1768-
(block_bytes % max_store_bytes) == 0)
1769-
? 32
1770-
: 16;
1771-
static constexpr uint32_t num_channel = mem_transpose
1772-
? (simd_channel >= block_size_x) ? block_size_x : simd_channel
1773-
: (simd_channel >= block_size_y) ? block_size_y
1774-
: simd_channel;
1762+
static constexpr uint32_t vector_size =
1763+
((mem_transpose ? block_size_y : block_size_x) + pack_factor - 1) /
1764+
pack_factor;
17751765

1776-
static constexpr uint32_t vector_size = mem_transpose
1777-
? (block_size_y + pack_factor - 1) / pack_factor
1778-
: (block_size_x + pack_factor - 1) / pack_factor;
1766+
using load_store_attr = load_store_attr_t<msg_type::block_1d, arch_tag>;
1767+
static constexpr uint32_t max_prefetch_vec_len =
1768+
load_store_attr::max_prefetch_vec_len;
17791769

1780-
static constexpr uint32_t mem_tile_size_w =
1781-
mem_transpose ? tile_size_y : tile_size_x;
1782-
static constexpr uint32_t mem_tile_size_h =
1783-
mem_transpose ? tile_size_x : tile_size_y;
1784-
using load_store_attr =
1785-
typename arch_attr_t<arch_tag>::template load_store_attr<message_type>;
1786-
static constexpr uint32_t special_prefetch_width =
1787-
load_store_attr::special_prefetch_width_in_bytes / sizeof(dtype);
1788-
static constexpr uint32_t normal_prefetch_width =
1789-
load_store_attr::max_load_width_in_bytes / sizeof(dtype);
1790-
static constexpr bool is_special_prefetch =
1791-
(mem_tile_size_w % special_prefetch_width) == 0;
1770+
static constexpr uint32_t max_channel =
1771+
max_prefetch_vec_len / (vector_size * sizeof(prefetch_dtype));
17921772

1793-
static constexpr uint32_t block_size_w = is_special_prefetch
1794-
? special_prefetch_width
1795-
: (normal_prefetch_width > mem_tile_size_w ? mem_tile_size_w
1796-
: normal_prefetch_width);
1797-
static constexpr uint32_t block_size_h =
1798-
load_store_attr::max_load_height_in_elem;
1799-
// could have over-prefetch, but that's should be fine
1800-
static constexpr uint32_t max_num_block_w =
1801-
(mem_tile_size_w + block_size_w - 1) / block_size_w;
1802-
static constexpr uint32_t num_coop_sg = num_coop_sg_;
1803-
static constexpr uint32_t num_coop_sg_w =
1804-
detail::gcd<num_coop_sg, max_num_block_w>::value;
1805-
static constexpr uint32_t num_coop_sg_h = num_coop_sg / num_coop_sg_w;
1773+
static constexpr uint32_t select_channel(const uint32_t channel) {
1774+
return (channel >= load_store_attr::max_channel_num)
1775+
? load_store_attr::max_channel_num
1776+
: channel >= 16 ? 16
1777+
: channel >= 8 ? 8
1778+
: 1;
1779+
}
18061780

1807-
static constexpr uint32_t num_block_w = max_num_block_w / num_coop_sg_w;
1808-
static constexpr uint32_t tile_size_w = block_size_w * num_block_w;
1809-
static constexpr uint32_t tile_size_h =
1810-
(mem_tile_size_h + num_coop_sg_h - 1) / num_coop_sg_h;
1811-
static constexpr uint32_t num_block_h =
1812-
(tile_size_h + block_size_h - 1) / block_size_h;
1781+
static constexpr uint32_t num_channel = select_channel(
1782+
std::min(mem_transpose ? block_size_x : block_size_y, max_channel));
1783+
1784+
// static constexpr uint32_t mem_tile_size_w =
1785+
// mem_transpose ? tile_size_y : tile_size_x;
1786+
// static constexpr uint32_t mem_tile_size_h =
1787+
// mem_transpose ? tile_size_x : tile_size_y;
1788+
1789+
// static constexpr uint32_t special_prefetch_width =
1790+
// load_store_attr::special_prefetch_width_in_bytes / sizeof(dtype);
1791+
// static constexpr uint32_t normal_prefetch_width =
1792+
// load_store_attr::max_load_width_in_bytes / sizeof(dtype);
1793+
// static constexpr bool is_special_prefetch =
1794+
// (mem_tile_size_w % special_prefetch_width) == 0;
1795+
1796+
// static constexpr uint32_t block_size_w = is_special_prefetch
1797+
// ? special_prefetch_width
1798+
// : (normal_prefetch_width > mem_tile_size_w ? mem_tile_size_w
1799+
// : normal_prefetch_width);
1800+
// static constexpr uint32_t block_size_h =
1801+
// load_store_attr::max_load_height_in_elem;
1802+
// // could have over-prefetch, but that's should be fine
1803+
// static constexpr uint32_t max_num_block_w =
1804+
// (mem_tile_size_w + block_size_w - 1) / block_size_w;
1805+
// static constexpr uint32_t num_coop_sg = num_coop_sg_;
1806+
// static constexpr uint32_t num_coop_sg_w =
1807+
// detail::gcd<num_coop_sg, max_num_block_w>::value;
1808+
// static constexpr uint32_t num_coop_sg_h = num_coop_sg / num_coop_sg_w;
1809+
1810+
// static constexpr uint32_t num_block_w = max_num_block_w / num_coop_sg_w;
1811+
// static constexpr uint32_t tile_size_w = block_size_w * num_block_w;
1812+
// static constexpr uint32_t tile_size_h =
1813+
// (mem_tile_size_h + num_coop_sg_h - 1) / num_coop_sg_h;
1814+
// static constexpr uint32_t num_block_h =
1815+
// (tile_size_h + block_size_h - 1) / block_size_h;
18131816

18141817
xetla_vector<uint32_t, num_channel> channel_offset;
1818+
xetla_vector<uint32_t, num_channel> step_x;
1819+
xetla_vector<uint32_t, num_channel> step_y;
1820+
18151821
uint64_t base_offset;
18161822
uint32_t base_x;
18171823
uint32_t base_y;
@@ -1848,13 +1854,15 @@ struct prefetch_payload_t<
18481854
return *this;
18491855
}
18501856

1851-
inline prefetch_payload_t(mem_desc_t& mem_desc, uint32_t coop_id = 0) {
1852-
uint32_t coop_id_x = coop_id % num_coop_sg_w;
1853-
uint32_t coop_id_y = coop_id / num_coop_sg_w;
1857+
inline prefetch_payload_t(
1858+
mem_desc_t& mem_desc,
1859+
[[maybe_unused]] uint32_t coop_id = 0) {
1860+
// uint32_t coop_id_x = coop_id % num_coop_sg_w;
1861+
// uint32_t coop_id_y = coop_id / num_coop_sg_w;
18541862

18551863
pitch_in_bytes = mem_desc.shape.stride * sizeof(dtype);
1856-
base_x = mem_desc.coord.x + coop_id_x * tile_size_w;
1857-
base_y = mem_desc.coord.y + coop_id_y * tile_size_h;
1864+
base_x = mem_desc.coord.x;
1865+
base_y = mem_desc.coord.y;
18581866
width_in_elems = mem_desc.shape.x;
18591867
height_in_elems = mem_desc.shape.y;
18601868
base_offset = mem_transpose
@@ -1874,13 +1882,15 @@ struct prefetch_payload_t<
18741882
int surface_pitch,
18751883
int surface_offset_x,
18761884
int surface_offset_y,
1877-
uint32_t coop_id = 0) {
1878-
uint32_t coop_id_x = coop_id % num_coop_sg_w;
1879-
uint32_t coop_id_y = coop_id / num_coop_sg_w;
1885+
[[maybe_unused]] uint32_t coop_id = 0) {
1886+
// uint32_t coop_id_x = coop_id % num_coop_sg_w;
1887+
// uint32_t coop_id_y = coop_id / num_coop_sg_w;
1888+
// base_x = surface_offset_x + coop_id_x * tile_size_w;
1889+
// base_y = surface_offset_y + coop_id_y * tile_size_h;
18801890

18811891
pitch_in_bytes = surface_pitch * sizeof(dtype);
1882-
base_x = surface_offset_x + coop_id_x * tile_size_w;
1883-
base_y = surface_offset_y + coop_id_y * tile_size_h;
1892+
base_x = surface_offset_x;
1893+
base_y = surface_offset_y;
18841894
width_in_elems = surface_width;
18851895
height_in_elems = surface_height;
18861896
base_offset = mem_transpose
@@ -1893,13 +1903,17 @@ struct prefetch_payload_t<
18931903
channel_offset = channel_index * pitch_in_bytes;
18941904
}
18951905

1896-
inline void init(mem_desc_t& mem_desc, uint32_t coop_id = 0) {
1897-
uint32_t coop_id_x = coop_id % num_coop_sg_w;
1898-
uint32_t coop_id_y = coop_id / num_coop_sg_w;
1906+
inline void init(
1907+
mem_desc_t& mem_desc,
1908+
[[maybe_unused]] uint32_t coop_id = 0) {
1909+
// uint32_t coop_id_x = coop_id % num_coop_sg_w;
1910+
// uint32_t coop_id_y = coop_id / num_coop_sg_w;
1911+
// base_x = mem_desc.coord.x + coop_id_x * tile_size_w;
1912+
// base_y = mem_desc.coord.y + coop_id_y * tile_size_h;
18991913

19001914
pitch_in_bytes = mem_desc.shape.stride * sizeof(dtype);
1901-
base_x = mem_desc.coord.x + coop_id_x * tile_size_w;
1902-
base_y = mem_desc.coord.y + coop_id_y * tile_size_h;
1915+
base_x = mem_desc.coord.x;
1916+
base_y = mem_desc.coord.y;
19031917
width_in_elems = mem_desc.shape.x;
19041918
height_in_elems = mem_desc.shape.y;
19051919
base_offset = mem_transpose
@@ -1955,9 +1969,10 @@ struct prefetch_payload_t<
19551969
reg_layout_>,
19561970
num_coop_sg_,
19571971
arch_tag_,
1958-
std::enable_if_t<(arch_has_2d_load_store<arch_tag_>)&&(
1959-
((tile_size_y_ != 1) && mem_layout_ == mem_layout::row_major) ||
1960-
((tile_size_x_ != 1) && mem_layout_ == mem_layout::col_major))>> {
1972+
std::enable_if_t<
1973+
(arch_has_2d_load_store<arch_tag_>) &&
1974+
(((tile_size_y_ != 1) && mem_layout_ == mem_layout::row_major) ||
1975+
((tile_size_x_ != 1) && mem_layout_ == mem_layout::col_major))>> {
19611976
using dtype = dtype_;
19621977
using mem_desc_t =
19631978
mem_desc_t<dtype_, mem_layout_, mem_space::global, alignment_>;

include/subgroup/tile/impl/prefetch_xe.hpp

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -104,8 +104,7 @@ tile_prefetch(payload_t& payload) {
104104
using prefetch_dtype = typename payload_t::prefetch_dtype;
105105
constexpr uint32_t num_channel = payload_t::num_channel;
106106
#pragma unroll
107-
for (uint32_t i = 0; i < tile_desc::tile_size_y / tile_desc::block_size_y;
108-
i++) {
107+
for (uint32_t i = 0; i < tile_desc::num_block_y; i++) {
109108
uint32_t offset_y = i * tile_desc::block_size_y;
110109
#pragma unroll
111110
for (uint32_t j = 0; j < tile_desc::num_block_x; j++) {
@@ -126,7 +125,6 @@ tile_prefetch(payload_t& payload) {
126125
L2>(
127126
payload.base_ptr,
128127
payload.channel_offset + payload.base_offset + address_offset);
129-
// }
130128
}
131129
}
132130
}

0 commit comments

Comments
 (0)