Skip to content
This repository was archived by the owner on Aug 30, 2024. It is now read-only.

Commit 193877c

Browse files
committed
channel num -> 1, 8, 16, 32
1 parent 4f30651 commit 193877c

File tree

2 files changed

+16
-8
lines changed

2 files changed

+16
-8
lines changed

include/subgroup/tile/impl/payload_xe.hpp

Lines changed: 15 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1138,12 +1138,11 @@ struct mem_payload_t<
11381138
static constexpr uint32_t block_bytes =
11391139
block_size_x * block_size_y * sizeof(dtype);
11401140

1141-
// using mem_dtype = uint32_t;
1142-
11431141
static constexpr uint32_t block_per_row_bytes = std::min(
11441142
(mem_transpose ? block_size_y : block_size_x) * uint32_t(sizeof(dtype)),
11451143
alignment_in_bytes);
11461144

1145+
// using mem_dtype = uint32_t;
11471146
using mem_dtype = typename std::conditional<
11481147
(block_per_row_bytes % sizeof(uint64_t) == 0),
11491148
uint64_t,
@@ -1160,14 +1159,23 @@ struct mem_payload_t<
11601159

11611160
// for pvc, we can use simd16 or simd32
11621161
using load_store_attr = load_store_attr_t<msg_type::block_1d, arch_tag>;
1163-
static constexpr uint32_t max_bytes = load_store_attr::max_load_vec_len;
1162+
static constexpr uint32_t max_bytes =
1163+
std::min(load_store_attr::max_load_vec_len, block_bytes);
11641164

1165-
static constexpr uint32_t simd_channel =
1165+
static constexpr uint32_t max_channel =
11661166
max_bytes / (simd_exec_size * sizeof(mem_dtype));
11671167

1168-
static constexpr uint32_t num_channel = mem_transpose
1169-
? std::min(block_size_x, simd_channel)
1170-
: std::min(block_size_y, simd_channel);
1168+
1169+
// Maps a candidate channel count onto one of the values 1, 8, 16 or 32.
//   - 32 is returned only when `channel >= 32` and `arch_tag` equals
//     `gpu_arch::XeHpc`; on any other arch_tag the result is capped at 16.
//   - 16 is returned for `channel >= 16`, 8 for `channel >= 8`.
//   - Every smaller input (0 through 7) yields 1, so counts such as 2 or 4
//     are never produced.
// @param channel  candidate channel count to be mapped.
// @return         the selected channel count (1, 8, 16 or 32).
static constexpr uint32_t select_channel(const uint32_t channel) {
1170+
return (channel >= 32 && arch_tag == gpu_arch::XeHpc) ? 32
1171+
: channel >= 16 ? 16
1172+
: channel >= 8 ? 8
1173+
: 1;
1174+
}
1175+
1176+
static constexpr uint32_t num_channel = select_channel(
1177+
mem_transpose ? std::min(block_size_x, max_channel)
1178+
: std::min(block_size_y, max_channel));
11711179

11721180
xetla_vector<uint32_t, num_channel> channel_offset;
11731181
xetla_vector<uint32_t, num_channel> step_x;

tests/integration/gemv/int4/main.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ class test_col_major_1 {
3737
static constexpr size_t wg_n = 1;
3838
static constexpr size_t sg_m = 1;
3939
static constexpr size_t sg_n = 1;
40-
static constexpr size_t sg_k = 512 / 1;
40+
static constexpr size_t sg_k = 512 / sg_m;
4141
static constexpr size_t dequant_s = 128;
4242
// static constexpr quant_mode quant_mode = quant_mode::S4_ASYM;
4343
static constexpr quant_mode quant_mode = quant_mode::S4_FULLRANGE_NO_ZP;

0 commit comments

Comments
 (0)