Skip to content
This repository was archived by the owner on Aug 30, 2024. It is now read-only.

Commit 510c89e

Browse files
sunjiweiswift authored and DDEle committed
update API in load/store
1 parent 23191ac commit 510c89e

File tree

8 files changed

+244
-75
lines changed

8 files changed

+244
-75
lines changed

include/common/core/arch_config.hpp

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,16 +35,38 @@ template <>
3535
struct load_store_attr_t<msg_type::block_2d, gpu_arch::XeHpc> {
3636
/// HW limitation checks https://gfxspecs.intel.com/Predator/Home/Index/55490
3737
static constexpr bool has_hw_block_2d = true;
38+
// If Transposed and Transformed are both set to false
39+
// BlockHeight must not exceed 32.
3840
static constexpr uint32_t max_load_height_in_elem = 32;
41+
42+
// BlockWidth * NBlocks must not exceed 64 for bytes, 32 for words, 16 for
43+
// dwords, and 8 for qwords.
3944
static constexpr uint32_t max_load_width_in_bytes = 64;
45+
46+
// If Transposed is true then
47+
// BlockWidth must be 1,2,4 for qwords and be in range [1..8] for dwords.
4048
static constexpr uint32_t max_trans_load_width_in_bytes = 32;
49+
50+
// If Transformed is true
51+
// BlockWidth must be in range [4..16] for bytes and [2..16] for word.
4152
static constexpr uint32_t max_vnni_load_width_in_elems = 16;
53+
54+
// BlockHeight must be in range [4..32] for bytes and [2..32] for words.
4255
static constexpr uint32_t min_vnni_load_height_in_bytes = 4;
4356

57+
// BlockHeight must not exceed 8.
4458
static constexpr uint32_t max_store_height_in_elem = 8;
59+
60+
// BlockWidth must not exceed 64 for bytes, 32 for words, 16 for dwords, and 8
61+
// for qwords.
4562
static constexpr uint32_t max_store_width_in_bytes = 64;
4663

64+
// BlockHeight must not exceed 32.
65+
// BlockWidth * NBlocks must not exceed 64 for bytes, 32 for words, 16 for
66+
// dwords, and 8 for qwords.
4767
static constexpr uint32_t max_load_size_in_bytes = 2048;
68+
69+
// BlockWidth * BlockHeight * sizeof(T) must not exceed 512.
4870
static constexpr uint32_t max_store_size_in_bytes = 512;
4971

5072
static constexpr uint32_t special_prefetch_width_in_bytes = 64;

include/common/core/memory.hpp

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -458,8 +458,16 @@ __XETLA_API xetla_vector<T, N> xetla_load_global(
458458
size_t SurfacePitch,
459459
int X,
460460
int Y) {
461-
return __ESIMD_ENS::lsc_load_2d(
462-
Ptr, SurfaceWidth, SurfaceHeight, SurfacePitch, X, Y);
461+
return __ESIMD_ENS::lsc_load_2d<
462+
T,
463+
BlockWidth,
464+
BlockHeight,
465+
NBlocks,
466+
Transposed,
467+
Transformed,
468+
gpu::xetla::detail::get_cache_hint(L1H),
469+
gpu::xetla::detail::get_cache_hint(L2H),
470+
N>(Ptr, SurfaceWidth, SurfaceHeight, SurfacePitch, X, Y);
463471
}
464472

465473
/// simd<T, N> block_load(const T* ptr, size_t byte_offset,
@@ -724,7 +732,12 @@ __XETLA_API void xetla_store_global(
724732
int X,
725733
int Y,
726734
xetla_vector<T, N> Vals) {
727-
__ESIMD_ENS::lsc_store_2d(
735+
__ESIMD_ENS::lsc_store_2d<
736+
T,
737+
BlockWidth,
738+
BlockHeight,
739+
gpu::xetla::detail::get_cache_hint(L1H),
740+
gpu::xetla::detail::get_cache_hint(L2H)>(
728741
Ptr, SurfaceWidth, SurfaceHeight, SurfacePitch, X, Y, Vals);
729742
}
730743
/// template <typename T, int N, int VS = 1, typename OffsetT,

include/common/utils/common.hpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -275,8 +275,8 @@ enum class reg_layout : uint8_t {
275275
tiled = 1,
276276
vnni_tiled = 2,
277277
transpose_tiled = 3,
278-
/// this is vnni tiled format, but for each block, they are stored in col
279-
/// major order
278+
/// this is vnni tiled format, but for each block, they are stored in
279+
/// col-major order
280280
vnni_tiled_col_major = 4
281281
};
282282
enum class store_op : uint8_t {

include/subgroup/tile/impl/load_xe.hpp

Lines changed: 108 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -94,7 +94,8 @@ tile_load(tile_t& tile, payload_t& payload) {
9494
static constexpr gpu_arch arch_tag = payload_t::arch_tag;
9595

9696
static constexpr reg_layout reg_layout_ = tile_desc::register_layout;
97-
static constexpr bool is_vnni_reverse = payload_t::mem_dword_transpose &&
97+
static constexpr bool is_vnni_reverse =
98+
payload_t::mem_dword_qword_transpose &&
9899
((reg_layout_ == reg_layout::tiled) ||
99100
(reg_layout_ == reg_layout::transpose_tiled));
100101
static constexpr bool reg_transpose = tile_desc::reg_transpose;
@@ -121,28 +122,27 @@ tile_load(tile_t& tile, payload_t& payload) {
121122
mem_transpose ? max_trans_block_width : max_load_block_height;
122123
static constexpr uint32_t ld_blk_size_y = reg_transpose
123124
? block_size_y
124-
: (block_size_y > ld_blk_size_y_limit ? ld_blk_size_y_limit
125-
: block_size_y);
126-
127-
// array len is used to make sure memory load is cache line aligned
128-
// disabled while register or memory transpose
129-
static constexpr uint8_t arr_len_candidate =
130-
(reg_transpose ||
131-
mem_transpose
132-
// block elements should be integer
133-
// times of register bytes
134-
|| ((block_size_y * block_size_x) % elems_per_reg != 0)
135-
// tail blocks also need to meet above condition
136-
||
137-
(((tile_size_y % block_size_y) * block_size_x) % elems_per_reg != 0)) ||
138-
(block_size_y > ld_blk_size_y_limit)
139-
? 1
140-
: (((tile_size_x % elems_per_CL) == 0)
141-
? (((elems_per_CL % block_size_x) == 0)
142-
? elems_per_CL / block_size_x
143-
: 1)
144-
: ((tile_size_x < elems_per_CL) ? (tile_size_x / block_size_x)
145-
: 1));
125+
: std::min(ld_blk_size_y_limit, block_size_y)
126+
127+
// array len is used to make sure memory load is cache line aligned
128+
// disabled while register or memory transpose
129+
static constexpr uint8_t arr_len_candidate =
130+
(reg_transpose ||
131+
mem_transpose
132+
// block elements should be integer
133+
// times of register bytes
134+
|| ((block_size_y * block_size_x) % elems_per_reg != 0)
135+
// tail blocks also need to meet above condition
136+
|| (((tile_size_y % block_size_y) * block_size_x) % elems_per_reg !=
137+
0)) ||
138+
(block_size_y > ld_blk_size_y_limit)
139+
? 1
140+
: (((tile_size_x % elems_per_CL) == 0)
141+
? (((elems_per_CL % block_size_x) == 0)
142+
? elems_per_CL / block_size_x
143+
: 1)
144+
: ((tile_size_x < elems_per_CL) ? (tile_size_x / block_size_x)
145+
: 1));
146146
static constexpr bool is_valid_arr_len_candidate = (arr_len_candidate == 1) ||
147147
(arr_len_candidate == 2) || (arr_len_candidate == 4);
148148

@@ -203,14 +203,31 @@ tile_load(tile_t& tile, payload_t& payload) {
203203
for (uint32_t ii = 0; ii < block_size_y / ld_blk_size_y; ++ii) {
204204
constexpr uint32_t load_elems = ld_blk_size_y * block_size_x * arr_len;
205205

206-
reg_tmp.xetla_format<native_type_t<load_dtype>>() = xetla_tload_global<
206+
// reg_tmp.xetla_format<native_type_t<load_dtype>>() =
207+
// xetla_tload_global<
208+
// load_dtype,
209+
// ld_blk_height * block_size_x * arr_len / scale_factor,
210+
// L1,
211+
// L2,
212+
// trans,
213+
// mem_transform,
214+
// arch_tag>(tdesc);
215+
reg_tmp.xetla_format<native_type_t<load_dtype>>() = xetla_load_global<
207216
load_dtype,
208-
ld_blk_height * block_size_x * arr_len / scale_factor,
209-
L1,
210-
L2,
217+
block_size_x / scale_factor,
218+
block_size_y,
219+
num_block,
211220
trans,
212221
mem_transform,
213-
arch_tag>(tdesc);
222+
L1,
223+
L2>(
224+
(load_dtype*)::gpu::xetla::detail::xetla_get_tensor_base_address(
225+
tdesc),
226+
::gpu::xetla::detail::xetla_get_tensor_width_x(tdesc),
227+
::gpu::xetla::detail::xetla_get_tensor_width_y(tdesc),
228+
::gpu::xetla::detail::xetla_get_tensor_pitch_x(tdesc),
229+
::gpu::xetla::detail::xetla_get_tensor_offset_x(tdesc),
230+
::gpu::xetla::detail::xetla_get_tensor_offset_y(tdesc));
214231
if constexpr (reg_transpose && trans) {
215232
reg_blk.xetla_select<load_elems, 1>(ii * load_elems)
216233
.xetla_format<native_type_t<load_dtype>>() =
@@ -256,14 +273,30 @@ tile_load(tile_t& tile, payload_t& payload) {
256273
tdesc.xetla_format<uint32_t>(), block_widthx_widthy_arrlen);
257274

258275
reg_blk.xetla_select<load_elems, 1>(remained_start)
259-
.xetla_format<native_type_t<load_dtype>>() = xetla_tload_global<
276+
.xetla_format<native_type_t<load_dtype>>() = xetla_load_global<
260277
load_dtype,
261-
(load_elems / scale_factor),
262-
L1,
263-
L2,
278+
block_size_x / scale_factor,
279+
block_size_y,
280+
num_block,
264281
trans,
265282
mem_transform,
266-
arch_tag>(tdesc);
283+
L1,
284+
L2>(
285+
(load_dtype*)::gpu::xetla::detail::xetla_get_tensor_base_address(
286+
tdesc),
287+
::gpu::xetla::detail::xetla_get_tensor_width_x(tdesc),
288+
::gpu::xetla::detail::xetla_get_tensor_width_y(tdesc),
289+
::gpu::xetla::detail::xetla_get_tensor_pitch_x(tdesc),
290+
::gpu::xetla::detail::xetla_get_tensor_offset_x(tdesc),
291+
::gpu::xetla::detail::xetla_get_tensor_offset_y(tdesc));
292+
// xetla_tload_global<
293+
// load_dtype,
294+
// (load_elems / scale_factor),
295+
// L1,
296+
// L2,
297+
// trans,
298+
// mem_transform,
299+
// arch_tag>(tdesc);
267300
}
268301
}
269302
}
@@ -301,14 +334,30 @@ tile_load(tile_t& tile, payload_t& payload) {
301334
constexpr uint32_t load_elems =
302335
remained_ld_blk_size_y * block_size_x * arr_len;
303336

304-
reg_tmp.xetla_format<native_type_t<load_dtype>>() = xetla_tload_global<
337+
reg_tmp.xetla_format<native_type_t<load_dtype>>() = xetla_load_global<
305338
load_dtype,
306-
(ld_blk_height * block_size_x * arr_len / scale_factor),
307-
L1,
308-
L2,
339+
block_size_x / scale_factor,
340+
block_size_y,
341+
num_block,
309342
trans,
310343
mem_transform,
311-
arch_tag>(tdesc);
344+
L1,
345+
L2>(
346+
(load_dtype*)::gpu::xetla::detail::xetla_get_tensor_base_address(
347+
tdesc),
348+
::gpu::xetla::detail::xetla_get_tensor_width_x(tdesc),
349+
::gpu::xetla::detail::xetla_get_tensor_width_y(tdesc),
350+
::gpu::xetla::detail::xetla_get_tensor_pitch_x(tdesc),
351+
::gpu::xetla::detail::xetla_get_tensor_offset_x(tdesc),
352+
::gpu::xetla::detail::xetla_get_tensor_offset_y(tdesc));
353+
// xetla_tload_global<
354+
// load_dtype,
355+
// (ld_blk_height * block_size_x * arr_len / scale_factor),
356+
// L1,
357+
// L2,
358+
// trans,
359+
// mem_transform,
360+
// arch_tag>(tdesc);
312361

313362
if constexpr (reg_transpose && trans) {
314363
reg_blk.xetla_select<load_elems, 1>(ii * load_elems)
@@ -352,14 +401,30 @@ tile_load(tile_t& tile, payload_t& payload) {
352401
gpu::xetla::detail::xetla_set_block_widthx_widthy_arrlen(
353402
tdesc.xetla_format<uint32_t>(), block_widthx_widthy_arrlen);
354403
reg_blk.xetla_select<final_load_elems, 1>(final_start)
355-
.xetla_format<native_type_t<load_dtype>>() = xetla_tload_global<
404+
.xetla_format<native_type_t<load_dtype>>() = xetla_load_global<
356405
load_dtype,
357-
final_load_elems / scale_factor,
358-
L1,
359-
L2,
406+
block_size_x / scale_factor,
407+
block_size_y,
408+
num_block,
360409
trans,
361410
mem_transform,
362-
arch_tag>(tdesc);
411+
L1,
412+
L2>(
413+
(load_dtype*)::gpu::xetla::detail::xetla_get_tensor_base_address(
414+
tdesc),
415+
::gpu::xetla::detail::xetla_get_tensor_width_x(tdesc),
416+
::gpu::xetla::detail::xetla_get_tensor_width_y(tdesc),
417+
::gpu::xetla::detail::xetla_get_tensor_pitch_x(tdesc),
418+
::gpu::xetla::detail::xetla_get_tensor_offset_x(tdesc),
419+
::gpu::xetla::detail::xetla_get_tensor_offset_y(tdesc));
420+
// xetla_tload_global<
421+
// load_dtype,
422+
// final_load_elems / scale_factor,
423+
// L1,
424+
// L2,
425+
// trans,
426+
// mem_transform,
427+
// arch_tag>(tdesc);
363428
}
364429
}
365430
}

include/subgroup/tile/impl/payload_xe.hpp

Lines changed: 58 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -74,22 +74,46 @@ struct mem_payload_t<
7474
static constexpr bool trans = (mem_transpose ^ reg_transpose) &&
7575
!(std::is_same_v<dtype_, int4x2> || std::is_same_v<dtype_, int4x8>);
7676

77-
static constexpr bool mem_transform = (sizeof(dtype) < 4) && !mem_transpose &&
77+
// Transformed and Transposed cannot be set to true at the same time.
78+
static constexpr bool mem_transform = (sizeof(dtype) <= 2) && !trans &&
7879
(register_layout == reg_layout::vnni_tiled ||
7980
register_layout == reg_layout::vnni_tiled_col_major);
80-
static constexpr bool mem_dword_transpose = (sizeof(dtype) < 4) && trans;
81+
static constexpr bool mem_dword_qword_transpose =
82+
(sizeof(dtype) < 4) && trans;
8183

82-
using mem_dtype =
83-
typename std::conditional<mem_dword_transpose, uint32_t, dtype>::type;
84+
using mem_dtype = typename std::
85+
conditional<mem_dword_qword_transpose, uint32_t, dtype>::type;
8486
static constexpr uint32_t scale_factor = sizeof(mem_dtype) / sizeof(dtype);
87+
mem_dtype* base_ptr;
88+
uint32_t surface_width;
89+
uint32_t surface_height;
90+
uint32_t surface_pitch;
91+
int32_t offset_x;
92+
int32_t offset_y;
8593

8694
xetla_vector<uint32_t, 16 * num_block> payloads;
8795

8896
inline mem_payload_t(const this_payload_t& rhs) {
97+
this->base_ptr = rhs.base_ptr;
98+
this->surface_width = rhs.surface_width;
99+
this->surface_height = rhs.surface_height;
100+
this->surface_pitch = rhs.surface_pitch;
101+
this->offset_x = rhs.offset_x;
102+
this->offset_y = rhs.offset_y;
103+
89104
this->payloads = rhs.payloads;
90105
}
91106

92107
inline mem_payload_t(mem_desc_t& mem_desc) {
108+
this->base_ptr = (mem_dtype*)mem_desc.base.base;
109+
this->surface_width =
110+
(mem_transpose ? mem_desc.shape.y : mem_desc.shape.x) * sizeof(dtype);
111+
this->surface_height =
112+
(mem_transpose ? mem_desc.shape.x : mem_desc.shape.y);
113+
this->surface_pitch = mem_desc.shape.stride * sizeof(dtype);
114+
this->offset_x = mem_desc.coord.x;
115+
this->offset_y = mem_desc.coord.y;
116+
93117
xetla_tdescriptor base_tdesc = mem_desc.get_tdesc();
94118
int32_t offset = gpu::xetla::detail::xetla_get_tensor_offset_x(base_tdesc) /
95119
int32_t(scale_factor);
@@ -105,6 +129,13 @@ struct mem_payload_t<
105129
uint32_t surface_pitch,
106130
int32_t surface_offset_x = 0,
107131
int32_t surface_offset_y = 0) {
132+
this->base_ptr = (mem_dtype)p;
133+
this->surface_width = surface_width;
134+
this->surface_height = surface_height;
135+
this->surface_pitch = surface_pitch;
136+
this->offset_x = surface_offset_x;
137+
this->offset_y = surface_offset_y;
138+
108139
xetla_tdescriptor base_tdesc;
109140
xetla_fill_tdesc(
110141
base_tdesc.xetla_format<uint32_t>(),
@@ -118,6 +149,15 @@ struct mem_payload_t<
118149
}
119150

120151
__XETLA_API void init(mem_desc_t& mem_desc) {
152+
this->base_ptr = (mem_dtype*)mem_desc.base.base;
153+
this->surface_width =
154+
(mem_transpose ? mem_desc.shape.y : mem_desc.shape.x) * sizeof(dtype);
155+
this->surface_height =
156+
(mem_transpose ? mem_desc.shape.x : mem_desc.shape.y);
157+
this->surface_pitch = mem_desc.shape.stride * sizeof(dtype);
158+
this->offset_x = mem_desc.coord.x;
159+
this->offset_y = mem_desc.coord.y;
160+
121161
xetla_tdescriptor base_tdesc = mem_desc.get_tdesc();
122162
int32_t offset = gpu::xetla::detail::xetla_get_tensor_offset_x(base_tdesc) /
123163
int32_t(scale_factor);
@@ -141,6 +181,13 @@ struct mem_payload_t<
141181
uint32_t surface_pitch,
142182
int32_t surface_offset_x = 0,
143183
int32_t surface_offset_y = 0) {
184+
this->base_ptr = (mem_dtype)p;
185+
this->surface_width = surface_width;
186+
this->surface_height = surface_height;
187+
this->surface_pitch = surface_pitch;
188+
this->offset_x = surface_offset_x;
189+
this->offset_y = surface_offset_y;
190+
144191
xetla_tdescriptor base_tdesc;
145192
xetla_fill_tdesc(
146193
base_tdesc.xetla_format<uint32_t>(),
@@ -159,6 +206,13 @@ struct mem_payload_t<
159206
// ~mem_payload_t(){}
160207

161208
inline this_payload_t& operator=(const this_payload_t& rhs) {
209+
this->base_ptr = rhs.base_ptr;
210+
this->surface_width = rhs.surface_width;
211+
this->surface_height = rhs.surface_height;
212+
this->surface_pitch = rhs.surface_pitch;
213+
this->offset_x = rhs.offset_x;
214+
this->offset_y = rhs.offset_y;
215+
162216
this->payloads = rhs.payloads;
163217
return *this;
164218
}

0 commit comments

Comments (0)