Skip to content
This repository was archived by the owner on Aug 30, 2024. It is now read-only.

Commit 8407143

Browse files
committed
add xmx colmajor
fix
1 parent 5066c1c commit 8407143

File tree

4 files changed

+83
-52
lines changed

4 files changed

+83
-52
lines changed

include/experimental/group/gemm/impl/int4_dequantize_xe.hpp

Lines changed: 14 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -156,7 +156,12 @@ class gemm_t<
156156
: is_vnni_tiled_a ? reg_layout::vnni_tiled
157157
: reg_layout::tiled;
158158

159+
// reg_layout of the load result
159160
static constexpr reg_layout reg_layout_b =
161+
is_col_major_b ? reg_layout::transpose_tiled : reg_layout::tiled;
162+
163+
// reg_layout required by mma
164+
static constexpr reg_layout reg_layout_b_acc =
160165
// fpu
161166
compute_policy::mma_engine == mma_engine::fpu
162167
? (is_gemv ? reg_layout::transpose_tiled : reg_layout::tiled)
@@ -214,7 +219,7 @@ class gemm_t<
214219
tile_size_y_b,
215220
block_size_x_b,
216221
block_size_y_b,
217-
reg_layout_b>;
222+
reg_layout_b_acc>;
218223
using matB_acc_t = subgroup::tile_t<dtype_mma_b, matB_acc_tile_desc_t>;
219224

220225
public:
@@ -629,9 +634,10 @@ class gemm_t<
629634
matA_acc,
630635
i == args.inner_loop_count - 1);
631636
} else {
632-
if constexpr (is_col_major_b) {
633-
tile_transpose(matB_acc);
634-
}
637+
// The result of dequantize should always be (plain) tiled
638+
if constexpr (
639+
matB_acc_tile_desc_t::register_layout == reg_layout::vnni_tiled)
640+
subgroup::vnni_convert(matB_acc);
635641
tile_mma::mma(matC, matC, matB_acc, matA_acc);
636642
}
637643
if constexpr (enable_periodic_sync) {
@@ -696,9 +702,10 @@ class gemm_t<
696702
tile_mma::mma(
697703
matAcc, matAcc, matC, matB_acc, matA_acc, i == compute_stages - 1);
698704
} else {
699-
if constexpr (is_col_major_b) {
700-
tile_transpose(matB_acc);
701-
}
705+
// The result of dequantize should always be (plain) tiled
706+
if constexpr (
707+
matB_acc_tile_desc_t::register_layout == reg_layout::vnni_tiled)
708+
subgroup::vnni_convert(matB_acc);
702709
tile_mma::mma(matC, matC, matB_acc, matA_acc);
703710
}
704711
if constexpr (enable_periodic_sync) {

include/subgroup/tile/impl/op_function.hpp

Lines changed: 22 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -704,36 +704,37 @@ layout_convert(T_dst& dst, T_src& src) {
704704
}
705705
}
706706

707-
template <typename T>
708-
void dump_mat(
709-
T mat,
710-
size_t tile_x = T::reg_transpose ? T::tile_size_y : T::tile_size_x,
711-
size_t tile_y = T::reg_transpose ? T::tile_size_x : T::tile_size_y) {
712-
#pragma unroll
713-
for (size_t row = 0; row < tile_y; row++) {
714-
#pragma unroll
715-
for (size_t col = 0; col < tile_x; col++) {
716-
sycl::ext::oneapi::experimental::printf(
717-
"%x(%d) ",
718-
int(native_type_t<typename T::dtype>(mat.reg[row * tile_x + col])),
719-
int(native_type_t<typename T::dtype>(mat.reg[row * tile_x + col])));
720-
}
721-
sycl::ext::oneapi::experimental::printf("\n");
722-
}
723-
sycl::ext::oneapi::experimental::printf("\n ");
724-
}
725707
template <typename T>
726708
void dump_mat_reg(T mat, size_t tile_x, size_t tile_y) {
727709
#pragma unroll
728710
for (size_t row = 0; row < tile_y; row++) {
729711
#pragma unroll
730712
for (size_t col = 0; col < tile_x; col++) {
731-
sycl::ext::oneapi::experimental::printf(
732-
"%d ", (int)(sycl::half)mat[row * tile_x + col]);
713+
const auto&& v = int64_t(
714+
native_type_t<typename T::element_type>(mat[row * tile_x + col]));
715+
constexpr bool is_int32 =
716+
(std::is_same<typename T::element_type, int4x2>::value ||
717+
std::is_same<typename T::element_type, int4x8>::value ||
718+
std::is_same<typename T::element_type, uint32_t>::value ||
719+
std::is_same<typename T::element_type, int32_t>::value);
720+
constexpr bool is_int64 =
721+
(std::is_same<typename T::element_type, uint64_t>::value ||
722+
std::is_same<typename T::element_type, int64_t>::value);
723+
is_int32 ? sycl::ext::oneapi::experimental::printf(
724+
"%08x(%10u) ", int(v), int(v))
725+
: is_int64
726+
? sycl::ext::oneapi::experimental::printf("%016llx(%20llu) ", v, v)
727+
: sycl::ext::oneapi::experimental::printf("%3lld ", v);
733728
}
734729
sycl::ext::oneapi::experimental::printf("\n");
735730
}
736731
sycl::ext::oneapi::experimental::printf("\n");
737732
}
738-
733+
template <typename T>
734+
void dump_mat(
735+
T mat,
736+
size_t tile_x = T::reg_transpose ? T::tile_size_y : T::tile_size_x,
737+
size_t tile_y = T::reg_transpose ? T::tile_size_x : T::tile_size_y) {
738+
dump_mat_reg(mat.reg, tile_x, tile_y);
739+
}
739740
} // namespace gpu::xetla::subgroup

include/subgroup/tile/impl/tile_op_functor.hpp

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,12 @@ struct dequant_int4_weight_t {
8383
constexpr uint32_t block_size_y_b = matB_acc_t::block_size_y;
8484
static constexpr uint32_t pack_ratio = sizeof(typename matB_t::dtype) * 2;
8585

86+
// If the result of dequant should be tranposed before storing to matB_acc
87+
constexpr bool trans_acc =
88+
matB_t::register_layout == reg_layout::transpose_tiled &&
89+
(matB_acc_t::register_layout == reg_layout::tiled ||
90+
matB_acc_t::register_layout == reg_layout::vnni_tiled);
91+
8692
constexpr uint32_t num_block_x = tile_size_x_b / block_size_x_b;
8793
constexpr uint32_t num_block_y = tile_size_y_b / block_size_y_b;
8894
#pragma unroll
@@ -149,9 +155,18 @@ struct dequant_int4_weight_t {
149155
cvt_blk_i8.xetla_select<step, 1>(jj * block_size_y_b + ii) -
150156
int8_t(8);
151157
}
152-
dst_blk.xetla_select<step, 1>(jj * block_size_y_b + ii) =
153-
cvt_blk_i8.xetla_select<step, 1>(jj * block_size_y_b + ii) *
154-
scale.reg[scale_idx];
158+
// Scale and write back to matB_acc
159+
if constexpr (trans_acc) {
160+
dst_blk.xetla_select<step, block_size_x_b>(
161+
ii * block_size_x_b + jj) =
162+
cvt_blk_i8.xetla_select<step, 1>(jj * block_size_y_b + ii) *
163+
scale.reg[scale_idx];
164+
165+
} else {
166+
dst_blk.xetla_select<step, 1>(jj * block_size_y_b + ii) =
167+
cvt_blk_i8.xetla_select<step, 1>(jj * block_size_y_b + ii) *
168+
scale.reg[scale_idx];
169+
}
155170

156171
// sycl::ext::oneapi::experimental::printf(
157172
// "scale[%d] %f \n",

tests/integration/gemv/int4/main.cpp

Lines changed: 29 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -31,14 +31,14 @@ template <typename scalar_t>
3131
class test_col_major_1 {
3232
public:
3333
// Extract the parameters required by different test cases
34-
static constexpr size_t mat_m = 1;
34+
static constexpr size_t mat_m = 4096;
3535
static constexpr size_t mat_n = 4096;
3636
static constexpr size_t mat_k = 4096;
37-
static constexpr size_t wg_m = 1;
38-
static constexpr size_t wg_n = 1;
39-
static constexpr size_t sg_m = 1;
40-
static constexpr size_t sg_n = 1;
41-
static constexpr size_t sg_k = 512 / sg_m;
37+
static constexpr size_t wg_m = 64;
38+
static constexpr size_t wg_n = 32;
39+
static constexpr size_t sg_m = 16;
40+
static constexpr size_t sg_n = 8;
41+
static constexpr size_t sg_k = 32;
4242
static constexpr size_t dequant_s = 128;
4343
static constexpr quant_mode quant_mode = quant_mode::I4_SYM;
4444

@@ -109,14 +109,17 @@ int gemm_result_validate(
109109
bool result = buff_cmp::xetla_buff_cmp(data, other, "gemv validation");
110110

111111
#ifdef UT_DEBUG
112-
// for (uint32_t i = 0; i < m; i++) {
113-
// for (uint32_t j = 0; j < n; j++) {
114-
// std::cout << float(sycl::half(C[i * n + j])) << " ";
115-
// }
116-
// std::cout << std::endl;
117-
// }
112+
if (m * n <= 4096) {
113+
std::cout << "result:\n";
114+
for (uint32_t i = 0; i < m; i++) {
115+
for (uint32_t j = 0; j < n; j++) {
116+
std::cout << float(sycl::half(C[i * n + j])) << " ";
117+
}
118+
std::cout << "\n";
119+
}
120+
}
118121
#endif
119-
std::cout << (!result ? "FAILED\n" : "PASSED\n");
122+
std::cout << (!result ? "FAILED" : "PASSED") << std::endl;
120123
return result ? 0 : 1;
121124
}
122125

@@ -186,12 +189,15 @@ std::vector<data_type_acc_in> dequantize_weight(
186189
}
187190
}
188191
#ifdef UT_DEBUG
189-
// for (uint32_t i = 0; i < matrix_n; i++) {
190-
// for (uint32_t j = 0; j < matrix_k; j++) {
191-
// std::cout << float(sycl::half(b_out[i * matrix_k + j])) << " ";
192-
// }
193-
// std::cout << std::endl;
194-
// }
192+
if (matrix_n * matrix_k <= 4096) {
193+
std::cout << "dequantize_weight:\n";
194+
for (uint32_t i = 0; i < matrix_n; i++) {
195+
for (uint32_t j = 0; j < matrix_k; j++) {
196+
std::cout << float(sycl::half(b_out[i * matrix_k + j])) << " ";
197+
}
198+
std::cout << std::endl;
199+
}
200+
}
195201
#endif
196202
return b_out;
197203
}
@@ -386,12 +392,14 @@ void dequantize_gemv_run(int iter) {
386392
if constexpr (std::is_same_v<int4x2, data_type_b>) {
387393
B_h[i] = random_uint8();
388394
#ifdef UT_DEBUG
389-
B_h[i] = 0x77;
395+
B_h[i] = ((7 + i) % 15 + 1) * 0x11;
396+
if (i >= size_b)
397+
B_h[i] = -1;
390398
#endif
391399
} else if constexpr (std::is_same_v<int4x8, data_type_b>) {
392400
B_h[i] = random_uint32();
393401
#ifdef UT_DEBUG
394-
B_h[i] = 0x77777777;
402+
B_h[i] = ((7 + i) % 15 + 1) * 0x11111111;
395403
#endif
396404
}
397405
}

0 commit comments

Comments
 (0)