Skip to content
This repository was archived by the owner on Aug 30, 2024. It is now read-only.

Commit dd44636

Browse files
committed
add xmx colmajor
1 parent 3043b5a commit dd44636

File tree

6 files changed

+126
-104
lines changed

6 files changed

+126
-104
lines changed

include/experimental/group/gemm/impl/int4_dequantize_xe.hpp

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -157,6 +157,9 @@ class gemm_t<
157157
: reg_layout::tiled;
158158

159159
static constexpr reg_layout reg_layout_b =
160+
is_col_major_b ? reg_layout::transpose_tiled : reg_layout::tiled;
161+
162+
static constexpr reg_layout reg_layout_b_acc =
160163
// fpu
161164
compute_policy::mma_engine == mma_engine::fpu
162165
? (is_gemv ? reg_layout::transpose_tiled : reg_layout::tiled)
@@ -214,7 +217,7 @@ class gemm_t<
214217
tile_size_y_b,
215218
block_size_x_b,
216219
block_size_y_b,
217-
reg_layout_b>;
220+
reg_layout_b_acc>;
218221
using matB_acc_t = subgroup::tile_t<dtype_mma_b, matB_acc_tile_desc_t>;
219222

220223
public:
@@ -635,9 +638,9 @@ class gemm_t<
635638
matA_acc,
636639
i == args.inner_loop_count - 1);
637640
} else {
638-
if constexpr (is_col_major_b) {
639-
tile_transpose(matB_acc);
640-
}
641+
if constexpr (
642+
matB_acc_tile_desc_t::register_layout == reg_layout::vnni_tiled)
643+
subgroup::vnni_convert(matB_acc);
641644
tile_mma::mma(matC, matC, matB_acc, matA_acc);
642645
}
643646
SW_BARRIER();

include/subgroup/tile/impl/load_xe.hpp

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -486,14 +486,13 @@ tile_load(tile_t& tile, payload_t& payload) {
486486
(payload_t::mem_transpose ? sub_block_offset : 0);
487487
const uint32_t sub_block_offset_y = payload.base_y + offset_y +
488488
(payload_t::mem_transpose ? 0 : sub_block_offset);
489-
const auto offset_ch_dim =
490-
payload_t::trans ? sub_block_offset_x : sub_block_offset_y;
491-
const auto size_ch_dim = payload_t::trans ? payload.width_in_elems
492-
: payload.height_in_elems;
489+
const auto offset_ch_dim = payload_t::mem_transpose
490+
? sub_block_offset_x
491+
: sub_block_offset_y;
493492

494-
pred = offset_ch_dim + num_channel > size_ch_dim
493+
pred = offset_ch_dim + num_channel > payload.height_in_elems
495494
? (xetla_vector_gen<uint32_t, num_channel>(offset_ch_dim, 1) <
496-
size_ch_dim)
495+
payload.height_in_elems)
497496
: 1;
498497
}
499498
reg_tmp = xetla_load_global<

include/subgroup/tile/impl/op_function.hpp

Lines changed: 19 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -703,36 +703,34 @@ layout_convert(T_dst& dst, T_src& src) {
703703
}
704704
}
705705

706-
template <typename T>
707-
void dump_mat(
708-
T mat,
709-
size_t tile_x = T::reg_transpose ? T::tile_size_y : T::tile_size_x,
710-
size_t tile_y = T::reg_transpose ? T::tile_size_x : T::tile_size_y) {
711-
#pragma unroll
712-
for (size_t row = 0; row < tile_y; row++) {
713-
#pragma unroll
714-
for (size_t col = 0; col < tile_x; col++) {
715-
sycl::ext::oneapi::experimental::printf(
716-
"%x(%d) ",
717-
int(native_type_t<typename T::dtype>(mat.reg[row * tile_x + col])),
718-
int(native_type_t<typename T::dtype>(mat.reg[row * tile_x + col])));
719-
}
720-
sycl::ext::oneapi::experimental::printf("\n");
721-
}
722-
sycl::ext::oneapi::experimental::printf("\n ");
723-
}
724706
template <typename T>
725707
void dump_mat_reg(T mat, size_t tile_x, size_t tile_y) {
726708
#pragma unroll
727709
for (size_t row = 0; row < tile_y; row++) {
728710
#pragma unroll
729711
for (size_t col = 0; col < tile_x; col++) {
730-
sycl::ext::oneapi::experimental::printf(
731-
"%d ", (int)(sycl::half)mat[row * tile_x + col]);
712+
const auto&& v = int64_t(
713+
native_type_t<typename T::element_type>(mat[row * tile_x + col]));
714+
(std::is_same<typename T::element_type, int4x2>::value ||
715+
std::is_same<typename T::element_type, int4x8>::value ||
716+
std::is_same<typename T::element_type, uint32_t>::value ||
717+
std::is_same<typename T::element_type, int32_t>::value)
718+
? sycl::ext::oneapi::experimental::printf(
719+
"%08x(%10u) ", int(v), int(v))
720+
: (std::is_same<typename T::element_type, uint64_t>::value ||
721+
std::is_same<typename T::element_type, int64_t>::value)
722+
? sycl::ext::oneapi::experimental::printf("%016llx(%20llu) ", v, v)
723+
: sycl::ext::oneapi::experimental::printf("%3lld ", v);
732724
}
733725
sycl::ext::oneapi::experimental::printf("\n");
734726
}
735727
sycl::ext::oneapi::experimental::printf("\n");
736728
}
737-
729+
template <typename T>
730+
void dump_mat(
731+
T mat,
732+
size_t tile_x = T::reg_transpose ? T::tile_size_y : T::tile_size_x,
733+
size_t tile_y = T::reg_transpose ? T::tile_size_x : T::tile_size_y) {
734+
dump_mat_reg(mat.reg, tile_x, tile_y);
735+
}
738736
} // namespace gpu::xetla::subgroup

include/subgroup/tile/impl/payload_xe.hpp

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -441,8 +441,8 @@ struct mem_payload_t<
441441

442442
inline mem_payload_t(mem_desc_t& mem_tdesc) {
443443
pitch_in_bytes = mem_tdesc.shape.stride * sizeof(dtype);
444-
width_in_elems = mem_tdesc.shape.x;
445-
height_in_elems = mem_tdesc.shape.y;
444+
width_in_elems = mem_transpose ? mem_tdesc.shape.y : mem_tdesc.shape.x;
445+
height_in_elems = mem_transpose ? mem_tdesc.shape.x : mem_tdesc.shape.y;
446446
payload_bytes = mem_transpose ? (mem_tdesc.shape.x - 1) * pitch_in_bytes +
447447
mem_tdesc.shape.y * sizeof(dtype)
448448
: (mem_tdesc.shape.y - 1) * pitch_in_bytes +
@@ -481,8 +481,8 @@ struct mem_payload_t<
481481
pitch_in_bytes = mem_tdesc.shape.stride * sizeof(dtype);
482482
uint32_t offset_x = mem_tdesc.coord.x;
483483
uint32_t offset_y = mem_tdesc.coord.y;
484-
width_in_elems = mem_tdesc.shape.x;
485-
height_in_elems = mem_tdesc.shape.y;
484+
width_in_elems = mem_transpose ? mem_tdesc.shape.y : mem_tdesc.shape.x;
485+
height_in_elems = mem_transpose ? mem_tdesc.shape.x : mem_tdesc.shape.y;
486486
payload_bytes = mem_transpose ? (mem_tdesc.shape.x - 1) * pitch_in_bytes +
487487
mem_tdesc.shape.y * sizeof(dtype)
488488
: (mem_tdesc.shape.y - 1) * pitch_in_bytes +
@@ -950,8 +950,8 @@ struct mem_payload_t<
950950
pitch_in_bytes = mem_tdesc.shape.stride * sizeof(dtype);
951951
base_x = mem_tdesc.coord.x;
952952
base_y = mem_tdesc.coord.y;
953-
width_in_elems = mem_tdesc.shape.x;
954-
height_in_elems = mem_tdesc.shape.y;
953+
width_in_elems = mem_transpose ? mem_tdesc.shape.y : mem_tdesc.shape.x;
954+
height_in_elems = mem_transpose ? mem_tdesc.shape.x : mem_tdesc.shape.y;
955955
base_offset = mem_transpose
956956
? base_x * pitch_in_bytes + base_y * sizeof(dtype)
957957
: base_y * pitch_in_bytes + base_x * sizeof(dtype);
@@ -996,8 +996,8 @@ struct mem_payload_t<
996996
pitch_in_bytes = mem_tdesc.shape.stride * sizeof(dtype);
997997
base_x = mem_tdesc.coord.x;
998998
base_y = mem_tdesc.coord.y;
999-
width_in_elems = mem_tdesc.shape.x;
1000-
height_in_elems = mem_tdesc.shape.y;
999+
width_in_elems = mem_transpose ? mem_tdesc.shape.y : mem_tdesc.shape.x;
1000+
height_in_elems = mem_transpose ? mem_tdesc.shape.x : mem_tdesc.shape.y;
10011001
base_offset = mem_transpose
10021002
? base_x * pitch_in_bytes + base_y * sizeof(dtype)
10031003
: base_y * pitch_in_bytes + base_x * sizeof(dtype);
@@ -1193,8 +1193,8 @@ struct mem_payload_t<
11931193
pitch_in_bytes = mem_tdesc.shape.stride * sizeof(dtype);
11941194
base_x = mem_tdesc.coord.x;
11951195
base_y = mem_tdesc.coord.y;
1196-
width_in_elems = mem_tdesc.shape.x;
1197-
height_in_elems = mem_tdesc.shape.y;
1196+
width_in_elems = mem_transpose ? mem_tdesc.shape.y : mem_tdesc.shape.x;
1197+
height_in_elems = mem_transpose ? mem_tdesc.shape.x : mem_tdesc.shape.y;
11981198
base_offset = mem_transpose
11991199
? base_x * pitch_in_bytes + base_y * sizeof(dtype)
12001200
: base_y * pitch_in_bytes + base_x * sizeof(dtype);
@@ -1232,8 +1232,8 @@ struct mem_payload_t<
12321232
base_x = mem_tdesc.coord.x;
12331233
base_y = mem_tdesc.coord.y;
12341234

1235-
width_in_elems = mem_tdesc.shape.x;
1236-
height_in_elems = mem_tdesc.shape.y;
1235+
width_in_elems = mem_transpose ? mem_tdesc.shape.y : mem_tdesc.shape.x;
1236+
height_in_elems = mem_transpose ? mem_tdesc.shape.x : mem_tdesc.shape.y;
12371237
base_offset = mem_transpose
12381238
? base_x * pitch_in_bytes + base_y * sizeof(dtype)
12391239
: base_y * pitch_in_bytes + base_x * sizeof(dtype);

include/subgroup/tile/impl/tile_op_functor.hpp

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,11 @@ struct dequant_int4_weight_t {
8888
constexpr uint32_t block_size_y_b = matB_acc_t::block_size_y;
8989
static constexpr uint32_t pack_ratio = sizeof(typename matB_t::dtype) * 2;
9090

91+
constexpr bool trans_acc =
92+
matB_t::register_layout == reg_layout::transpose_tiled &&
93+
(matB_acc_t::register_layout == reg_layout::tiled ||
94+
matB_acc_t::register_layout == reg_layout::vnni_tiled);
95+
9196
constexpr uint32_t num_block_x = tile_size_x_b / block_size_x_b;
9297
constexpr uint32_t num_block_y = tile_size_y_b / block_size_y_b;
9398
#pragma unroll
@@ -154,9 +159,18 @@ struct dequant_int4_weight_t {
154159
cvt_blk_i8.xetla_select<step, 1>(jj * block_size_y_b + ii) -
155160
int8_t(8);
156161
}
157-
dst_blk.xetla_select<step, 1>(jj * block_size_y_b + ii) =
158-
cvt_blk_i8.xetla_select<step, 1>(jj * block_size_y_b + ii) *
159-
scale.reg[scale_idx];
162+
// Scale and write back to matB_acc
163+
if constexpr (trans_acc) {
164+
dst_blk.xetla_select<step, block_size_x_b>(
165+
ii * block_size_x_b + jj) =
166+
cvt_blk_i8.xetla_select<step, 1>(jj * block_size_y_b + ii) *
167+
scale.reg[scale_idx];
168+
169+
} else {
170+
dst_blk.xetla_select<step, 1>(jj * block_size_y_b + ii) =
171+
cvt_blk_i8.xetla_select<step, 1>(jj * block_size_y_b + ii) *
172+
scale.reg[scale_idx];
173+
}
160174

161175
// sycl::ext::oneapi::experimental::printf(
162176
// "scale[%d] %f \n",

tests/integration/gemv/int4/main.cpp

Lines changed: 66 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -30,14 +30,14 @@ constexpr size_t UNDEFINED_DATA_SIZE = 1024;
3030
class test_col_major_1 {
3131
public:
3232
// Extract the parameters required by different test cases
33-
static constexpr size_t mat_m = 1;
33+
static constexpr size_t mat_m = 4096;
3434
static constexpr size_t mat_n = 4096;
3535
static constexpr size_t mat_k = 4096;
36-
static constexpr size_t wg_m = 1;
37-
static constexpr size_t wg_n = 1;
38-
static constexpr size_t sg_m = 1;
39-
static constexpr size_t sg_n = 1;
40-
static constexpr size_t sg_k = 512 / sg_m;
36+
static constexpr size_t wg_m = 64;
37+
static constexpr size_t wg_n = 32;
38+
static constexpr size_t sg_m = 16;
39+
static constexpr size_t sg_n = 8;
40+
static constexpr size_t sg_k = 32;
4141
static constexpr size_t dequant_s = 128;
4242
// static constexpr quant_mode quant_mode = quant_mode::S4_ASYM;
4343
static constexpr quant_mode quant_mode = quant_mode::S4_FULLRANGE_NO_ZP;
@@ -46,8 +46,8 @@ class test_col_major_1 {
4646
static constexpr size_t global_kslicing = 1;
4747
static constexpr mem_layout layout_a = mem_layout::row_major;
4848
static constexpr mem_layout layout_b = mem_layout::col_major;
49-
static constexpr mma_engine mma_eng = mma_engine::fpu;
50-
static constexpr gpu_arch arch = gpu_arch::XeLpg;
49+
static constexpr mma_engine mma_eng = mma_engine::xmx;
50+
static constexpr gpu_arch arch = gpu_arch::XeHpg;
5151
using data_type_a = fp16;
5252
using data_type_b = int4x8;
5353
using data_type_c = fp16;
@@ -108,14 +108,17 @@ int gemm_result_validate(
108108
bool result = buff_cmp::xetla_buff_cmp(data, other, "gemv validation");
109109

110110
#ifdef UT_DEBUG
111-
// for (uint32_t i = 0; i < m; i++) {
112-
// for (uint32_t j = 0; j < n; j++) {
113-
// std::cout << float(sycl::half(C[i * n + j])) << " ";
114-
// }
115-
// std::cout << std::endl;
116-
// }
111+
if (m * n <= 4096) {
112+
std::cout << "result:\n";
113+
for (uint32_t i = 0; i < m; i++) {
114+
for (uint32_t j = 0; j < n; j++) {
115+
std::cout << float(sycl::half(C[i * n + j])) << " ";
116+
}
117+
std::cout << "\n";
118+
}
119+
}
117120
#endif
118-
std::cout << (!result ? "FAILED\n" : "PASSED\n");
121+
std::cout << (!result ? "FAILED" : "PASSED") << std::endl;
119122
return result ? 0 : 1;
120123
}
121124

@@ -185,12 +188,15 @@ std::vector<data_type_acc_in> dequantize_weight(
185188
}
186189
}
187190
#ifdef UT_DEBUG
188-
// for (uint32_t i = 0; i < matrix_n; i++) {
189-
// for (uint32_t j = 0; j < matrix_k; j++) {
190-
// std::cout << float(sycl::half(b_out[i * matrix_k + j])) << " ";
191-
// }
192-
// std::cout << std::endl;
193-
// }
191+
if (matrix_n * matrix_k <= 4096) {
192+
std::cout << "dequantize_weight:\n";
193+
for (uint32_t i = 0; i < matrix_n; i++) {
194+
for (uint32_t j = 0; j < matrix_k; j++) {
195+
std::cout << float(sycl::half(b_out[i * matrix_k + j])) << " ";
196+
}
197+
std::cout << std::endl;
198+
}
199+
}
194200
#endif
195201
return b_out;
196202
}
@@ -385,12 +391,14 @@ void dequantize_gemv_run(int iter) {
385391
if constexpr (std::is_same_v<int4x2, data_type_b>) {
386392
B_h[i] = random_uint8();
387393
#ifdef UT_DEBUG
388-
B_h[i] = 0x77;
394+
B_h[i] = ((7 + i) % 15 + 1) * 0x11;
395+
if (i >= size_b)
396+
B_h[i] = -1;
389397
#endif
390398
} else if constexpr (std::is_same_v<int4x8, data_type_b>) {
391399
B_h[i] = random_uint32();
392400
#ifdef UT_DEBUG
393-
B_h[i] = 0x77777777;
401+
B_h[i] = ((7 + i) % 15 + 1) * 0x11111111;
394402
#endif
395403
}
396404
}
@@ -473,43 +481,43 @@ void dequantize_gemv_run(int iter) {
473481
{// epilogue_args init list
474482
// It accepts the base pointer to matrix D, and its dimensions
475483
{bias_d, bias_add_shape}});
476-
typename gemm_op_t::template arguments_t<compute_policy::quant_mode> gemm_arg;
484+
using gemm_arg_t =
485+
typename gemm_op_t::template arguments_t<compute_policy::quant_mode>;
486+
gemm_arg_t gemm_arg;
477487
if constexpr (compute_policy::quant_mode == quant_mode::S4_FULLRANGE_NO_ZP) {
478-
gemm_arg =
479-
typename gemm_op_t::template arguments_t<compute_policy::quant_mode>(
480-
matrix_m,
481-
matrix_k,
482-
matrix_n,
483-
A_d,
484-
lda,
485-
B_d,
486-
ldb,
487-
C_d,
488-
ldc,
489-
scale_d,
490-
ld_scale,
491-
Acc_d,
492-
Cnt_d,
493-
epilogue_args);
488+
gemm_arg = gemm_arg_t(
489+
matrix_m,
490+
matrix_k,
491+
matrix_n,
492+
A_d,
493+
lda,
494+
B_d,
495+
ldb,
496+
C_d,
497+
ldc,
498+
scale_d,
499+
ld_scale,
500+
Acc_d,
501+
Cnt_d,
502+
epilogue_args);
494503
} else if constexpr (compute_policy::quant_mode == quant_mode::S4_ASYM) {
495-
gemm_arg =
496-
typename gemm_op_t::template arguments_t<compute_policy::quant_mode>(
497-
matrix_m,
498-
matrix_k,
499-
matrix_n,
500-
A_d,
501-
lda,
502-
B_d,
503-
ldb,
504-
C_d,
505-
ldc,
506-
scale_d,
507-
ld_scale,
508-
zero_pt_d,
509-
ld_zero_pt,
510-
Acc_d,
511-
Cnt_d,
512-
epilogue_args);
504+
gemm_arg = gemm_arg_t(
505+
matrix_m,
506+
matrix_k,
507+
matrix_n,
508+
A_d,
509+
lda,
510+
B_d,
511+
ldb,
512+
C_d,
513+
ldc,
514+
scale_d,
515+
ld_scale,
516+
zero_pt_d,
517+
ld_zero_pt,
518+
Acc_d,
519+
Cnt_d,
520+
epilogue_args);
513521
}
514522
cl::sycl::nd_range<3> nd_range = gemm_op_t::get_nd_range(gemm_arg);
515523
// if (!gemm_op_t::can_implement(gemm_arg)) {

0 commit comments

Comments
 (0)