Skip to content
This repository was archived by the owner on Aug 30, 2024. It is now read-only.

Commit 8817f54

Browse files
committed
fix XEHPC 2D load
1 parent b2dfad5 commit 8817f54

File tree

2 files changed

+6
-7
lines changed

2 files changed

+6
-7
lines changed

include/experimental/group/gemm/impl/int4_dequantize_xe.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -210,7 +210,7 @@ class gemm_t<
210210
typename mem_desc_b_t::dtype,
211211
mem_layout::row_major,
212212
mem_desc_b_t::space>>,
213-
// subgroup::msg_type_v<matB_tile_desc_t, mem_desc_b_t>,
213+
// subgroup::msg_type_v<matB_tile_desc_t, mem_desc_b_t>,
214214
arch_tag>;
215215
using matB_prefetch_payload_t = subgroup::
216216
prefetch_payload_t<mem_desc_b_t, matB_tile_desc_t, wg_size_y, arch_tag>;

include/subgroup/tile/impl/payload_xe.hpp

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -65,15 +65,14 @@ struct mem_payload_t<
6565
mem_payload_t<mem_desc_t, tile_desc, msg_type::block_2d, arch_tag>;
6666

6767
public:
68-
static constexpr bool mem_transpose =
69-
memory_layout == mem_layout::col_major &&
70-
!(std::is_same_v<dtype_, int4x2> || std::is_same_v<dtype_, int4x8>);
68+
static constexpr bool mem_transpose = memory_layout == mem_layout::col_major;
7169

7270
static constexpr reg_layout register_layout = tile_desc::register_layout;
7371
static constexpr bool reg_transpose =
7472
register_layout == reg_layout::transpose_tiled;
7573

76-
static constexpr bool trans = mem_transpose ^ reg_transpose;
74+
static constexpr bool trans = (mem_transpose ^ reg_transpose) &&
75+
!(std::is_same_v<dtype_, int4x2> || std::is_same_v<dtype_, int4x8>);
7776

7877
static constexpr bool mem_transform = (sizeof(dtype) < 4) && !mem_transpose &&
7978
(register_layout == reg_layout::vnni_tiled ||
@@ -1094,7 +1093,7 @@ struct mem_payload_t<
10941093
static constexpr reg_layout register_layout = tile_desc::register_layout;
10951094
static constexpr bool reg_transpose =
10961095
register_layout == reg_layout::transpose_tiled;
1097-
static constexpr bool trans = mem_transpose ^ reg_transpose &&
1096+
static constexpr bool trans = (mem_transpose ^ reg_transpose) &&
10981097
!(std::is_same_v<dtype_, int4x2> || std::is_same_v<dtype_, int4x8>);
10991098

11001099
static constexpr bool mem_transform = (sizeof(dtype) < 4) &&
@@ -1657,7 +1656,7 @@ struct prefetch_payload_t<
16571656
static constexpr reg_layout register_layout = tile_desc::register_layout;
16581657
static constexpr bool reg_transpose =
16591658
register_layout == reg_layout::transpose_tiled;
1660-
static constexpr bool trans = mem_transpose ^ reg_transpose &&
1659+
static constexpr bool trans = (mem_transpose ^ reg_transpose) &&
16611660
!(std::is_same_v<dtype_, int4x2> || std::is_same_v<dtype_, int4x8>);
16621661

16631662
using prefetch_dtype = typename std::conditional<

0 commit comments

Comments
 (0)