Skip to content
This repository was archived by the owner on Aug 30, 2024. It is now read-only.

Commit c39cdfc

Browse files
committed
fix GRU & int4gemm
1 parent d5dd203 commit c39cdfc

File tree

5 files changed

+63
-42
lines changed

5 files changed

+63
-42
lines changed

include/common/core/arch_config.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@ struct load_store_attr_t<msg_type::block_2d, gpu_arch::XeHpc> {
6565
// BlockWidth * NBlocks must not exceed 64 for bytes, 32 for words, 16 for
6666
// dwords, and 8 for qwords.
6767
static constexpr uint32_t max_load_size_in_bytes = 2048;
68-
68+
6969
// BlockWidth * BlockHeight * sizeof(T) must not exceed 512.
7070
static constexpr uint32_t max_store_size_in_bytes = 512;
7171

include/common/core/memory.hpp

Lines changed: 32 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -458,16 +458,38 @@ __XETLA_API xetla_vector<T, N> xetla_load_global(
458458
size_t SurfacePitch,
459459
int X,
460460
int Y) {
461-
return __ESIMD_ENS::lsc_load_2d<
462-
T,
463-
BlockWidth,
464-
BlockHeight,
465-
NBlocks,
466-
Transposed,
467-
Transformed,
468-
gpu::xetla::detail::get_cache_hint(L1H),
469-
gpu::xetla::detail::get_cache_hint(L2H),
470-
N>(Ptr, SurfaceWidth, SurfaceHeight, SurfacePitch, X, Y);
461+
if constexpr (BlockWidth * sizeof(T) < sizeof(uint32_t)) {
462+
constexpr auto scale_factor = sizeof(uint32_t) / sizeof(T);
463+
xetla_vector<uint32_t, N> ret = __ESIMD_ENS::lsc_load_2d<
464+
uint32_t,
465+
BlockWidth,
466+
BlockHeight,
467+
NBlocks,
468+
Transposed,
469+
Transformed,
470+
gpu::xetla::detail::get_cache_hint(L1H),
471+
gpu::xetla::detail::get_cache_hint(L2H),
472+
N>(
473+
reinterpret_cast<const uint32_t*>(Ptr),
474+
SurfaceWidth,
475+
SurfaceHeight,
476+
SurfacePitch,
477+
X / scale_factor,
478+
Y);
479+
return ret.xetla_format<T>().xetla_select<N, scale_factor>(
480+
X % scale_factor);
481+
} else {
482+
return __ESIMD_ENS::lsc_load_2d<
483+
T,
484+
BlockWidth,
485+
BlockHeight,
486+
NBlocks,
487+
Transposed,
488+
Transformed,
489+
gpu::xetla::detail::get_cache_hint(L1H),
490+
gpu::xetla::detail::get_cache_hint(L2H),
491+
N>(Ptr, SurfaceWidth, SurfaceHeight, SurfacePitch, X, Y);
492+
}
471493
}
472494

473495
/// simd<T, N> block_load(const T* ptr, size_t byte_offset,

include/subgroup/tile/impl/load_xe.hpp

Lines changed: 25 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -122,27 +122,27 @@ tile_load(tile_t& tile, payload_t& payload) {
122122
mem_transpose ? max_trans_block_width : max_load_block_height;
123123
static constexpr uint32_t ld_blk_size_y = reg_transpose
124124
? block_size_y
125-
: std::min(ld_blk_size_y_limit, block_size_y)
126-
127-
// array length is used to keep memory loads cache-line aligned;
128-
// it is disabled when register or memory transpose is enabled
129-
static constexpr uint8_t arr_len_candidate =
130-
(reg_transpose ||
131-
mem_transpose
132-
// the number of block elements should be an integer
134-
// multiple of the elements per register
134-
|| ((block_size_y * block_size_x) % elems_per_reg != 0)
135-
// tail blocks also need to meet above condition
136-
|| (((tile_size_y % block_size_y) * block_size_x) % elems_per_reg !=
137-
0)) ||
138-
(block_size_y > ld_blk_size_y_limit)
139-
? 1
140-
: (((tile_size_x % elems_per_CL) == 0)
141-
? (((elems_per_CL % block_size_x) == 0)
142-
? elems_per_CL / block_size_x
143-
: 1)
144-
: ((tile_size_x < elems_per_CL) ? (tile_size_x / block_size_x)
145-
: 1));
125+
: std::min(ld_blk_size_y_limit, block_size_y);
126+
127+
// array length is used to keep memory loads cache-line aligned;
128+
// it is disabled when register or memory transpose is enabled
129+
static constexpr uint8_t arr_len_candidate =
130+
(reg_transpose ||
131+
mem_transpose
132+
// the number of block elements should be an integer
133+
// multiple of the elements per register
134+
|| ((block_size_y * block_size_x) % elems_per_reg != 0)
135+
// tail blocks also need to meet above condition
136+
||
137+
(((tile_size_y % block_size_y) * block_size_x) % elems_per_reg != 0)) ||
138+
(block_size_y > ld_blk_size_y_limit)
139+
? 1
140+
: (((tile_size_x % elems_per_CL) == 0)
141+
? (((elems_per_CL % block_size_x) == 0)
142+
? elems_per_CL / block_size_x
143+
: 1)
144+
: ((tile_size_x < elems_per_CL) ? (tile_size_x / block_size_x)
145+
: 1));
146146
static constexpr bool is_valid_arr_len_candidate = (arr_len_candidate == 1) ||
147147
(arr_len_candidate == 2) || (arr_len_candidate == 4);
148148

@@ -213,16 +213,16 @@ tile_load(tile_t& tile, payload_t& payload) {
213213
// mem_transform,
214214
// arch_tag>(tdesc);
215215
reg_tmp.xetla_format<native_type_t<load_dtype>>() = xetla_load_global<
216-
load_dtype,
216+
native_type_t<load_dtype>,
217217
block_size_x / scale_factor,
218218
block_size_y,
219-
num_block,
219+
arr_len,
220220
trans,
221221
mem_transform,
222222
L1,
223223
L2>(
224-
(load_dtype*)::gpu::xetla::detail::xetla_get_tensor_base_address(
225-
tdesc),
224+
(native_type_t<load_dtype>*)::gpu::xetla::detail::
225+
xetla_get_tensor_base_address(tdesc),
226226
::gpu::xetla::detail::xetla_get_tensor_width_x(tdesc),
227227
::gpu::xetla::detail::xetla_get_tensor_width_y(tdesc),
228228
::gpu::xetla::detail::xetla_get_tensor_pitch_x(tdesc),

include/subgroup/tile/impl/payload_xe.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -129,7 +129,7 @@ struct mem_payload_t<
129129
uint32_t surface_pitch,
130130
int32_t surface_offset_x = 0,
131131
int32_t surface_offset_y = 0) {
132-
this->base_ptr = (mem_dtype)p;
132+
this->base_ptr = (mem_dtype*)p;
133133
this->surface_width = surface_width;
134134
this->surface_height = surface_height;
135135
this->surface_pitch = surface_pitch;

include/subgroup/tile/impl/store_xe.hpp

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -161,18 +161,17 @@ tile_store(tile_t& tile, payload_t& payload) {
161161
for (uint32_t ii = 0; ii < block_size_y / st_block_size_y; ++ii) {
162162
constexpr uint32_t store_elems =
163163
st_block_size_y * block_size_x * arr_len;
164-
auto st_blk =
164+
xetla_vector<dtype, store_elems> st_blk =
165165
combine_blk.xetla_select<store_elems, 1>(ii * store_elems);
166166
// xetla_tstore_global<dtype, store_elems, L1, L2, payload_t::arch_tag>(
167167
// tdesc, st_blk);
168168
xetla_store_global<
169169
dtype,
170-
block_size_x,
171-
block_size_y,
172-
num_block,
170+
block_size_x * arr_len,
171+
st_block_size_y,
173172
L1,
174173
L2>(
175-
::gpu::xetla::detail::xetla_get_tensor_base_address(tdesc),
174+
(dtype*)::gpu::xetla::detail::xetla_get_tensor_base_address(tdesc),
176175
::gpu::xetla::detail::xetla_get_tensor_width_x(tdesc),
177176
::gpu::xetla::detail::xetla_get_tensor_width_y(tdesc),
178177
::gpu::xetla::detail::xetla_get_tensor_pitch_x(tdesc),

0 commit comments

Comments
 (0)