Skip to content
This repository was archived by the owner on Aug 30, 2024. It is now read-only.

Commit 8fc6c57

Browse files
DDElesunjiweiswift
andauthored
XeTLA Zero-Passthrough (#321)
* sync fmha * support pass_thru for 2024.1 * XeTLA use mask with zero-passthrough * reformat --------- Co-authored-by: Sun, Jiwei1 <jiwei1.sun@intel.com>
1 parent 16a9a20 commit 8fc6c57

File tree

5 files changed

+86
-73
lines changed

5 files changed

+86
-73
lines changed

include/common/core/memory.hpp

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -500,12 +500,23 @@ __XETLA_API xetla_vector<T, N> xetla_load_global(
500500
xetla_vector<OffsetT, N / VS> byte_offsets,
501501
xetla_mask<N / VS> mask,
502502
xetla_vector<T, N> pass_thru) {
503+
#if __INTEL_LLVM_COMPILER >= 20240200
503504
__ESIMD_NS::properties props{
504505
__ESIMD_NS::cache_hint_L1<gpu::xetla::detail::get_cache_hint(L1H)>,
505506
__ESIMD_NS::cache_hint_L2<gpu::xetla::detail::get_cache_hint(L2H)>,
506507
__ESIMD_NS::alignment<alignment>};
507508

508509
return __ESIMD_NS::gather<T, N, VS>(p, byte_offsets, mask, pass_thru, props);
510+
#else
511+
constexpr data_size DS = data_size::default_size;
512+
return __ESIMD_ENS::lsc_gather<
513+
T,
514+
VS,
515+
gpu::xetla::detail::get_data_size(DS),
516+
gpu::xetla::detail::get_cache_hint(L1H),
517+
gpu::xetla::detail::get_cache_hint(L2H),
518+
N / VS>(p, byte_offsets, mask, pass_thru);
519+
#endif
509520
}
510521

511522
/// template <typename T, int N, int VS, typename OffsetT,

include/subgroup/tile/impl/load_xe.hpp

Lines changed: 23 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -457,6 +457,7 @@ tile_load(tile_t& tile, payload_t& payload) {
457457
constexpr uint32_t num_channel = payload_t::num_channel;
458458
constexpr uint32_t load_elems = num_channel * payload_t::vector_size;
459459
constexpr uint32_t pack_factor = payload_t::pack_factor;
460+
const xetla_vector<load_dtype, load_elems> reg_zeros(0);
460461

461462
auto channel_offset = payload.channel_offset + payload.base_offset;
462463
#pragma unroll
@@ -494,28 +495,35 @@ tile_load(tile_t& tile, payload_t& payload) {
494495
? (xetla_vector_gen<uint32_t, num_channel>(offset_ch_dim, 1) <
495496
size_ch_dim)
496497
: 1;
498+
reg_tmp = xetla_load_global<
499+
load_dtype,
500+
load_elems,
501+
payload_t::vector_size,
502+
L1,
503+
L2>(
504+
payload.base_ptr,
505+
channel_offset + address_offset,
506+
mask,
507+
reg_zeros);
508+
} else {
509+
reg_tmp = xetla_load_global<
510+
load_dtype,
511+
load_elems,
512+
payload_t::vector_size,
513+
L1,
514+
L2>(payload.base_ptr, channel_offset + address_offset, mask);
497515
}
498-
reg_tmp = xetla_load_global<
499-
load_dtype,
500-
load_elems,
501-
payload_t::vector_size,
502-
L1,
503-
L2>(payload.base_ptr, channel_offset + address_offset, mask);
504516

505517
if constexpr (
506518
payload_t::vector_size > 1 && payload_t::num_channel > 1) {
507519
xetla_vector<load_dtype, load_elems> reg_tmp_trans;
508520
#pragma unroll
509521
for (uint32_t iii = 0; iii < payload_t::num_channel; iii++) {
510-
if ((bool)mask[iii]) // TODO (dingyi): Delete after driver fix
511-
reg_tmp_trans.xetla_select<payload_t::vector_size, 1>(
512-
iii * payload_t::vector_size) =
513-
reg_tmp.xetla_select<
514-
payload_t::vector_size,
515-
payload_t::num_channel>(iii);
516-
else // TODO (dingyi): Delete after driver fix
517-
reg_tmp_trans.xetla_select<payload_t::vector_size, 1>(
518-
iii * payload_t::vector_size) = 0;
522+
reg_tmp_trans.xetla_select<payload_t::vector_size, 1>(
523+
iii * payload_t::vector_size) =
524+
reg_tmp.xetla_select<
525+
payload_t::vector_size,
526+
payload_t::num_channel>(iii);
519527
}
520528
reg_sub
521529
.xetla_select<load_elems * pack_factor, 1>(

include/subgroup/tile/impl/payload_xe.hpp

Lines changed: 8 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1655,12 +1655,11 @@ struct prefetch_payload_t<
16551655
reg_layout_>,
16561656
num_coop_sg_,
16571657
arch_tag_,
1658-
std::enable_if_t<
1659-
(!arch_has_2d_load_store<arch_tag_>) &&
1660-
(((block_size_y_ != 1 || tile_size_y_ != 1) &&
1661-
mem_layout_ == mem_layout::row_major) ||
1662-
((block_size_x_ != 1 || tile_size_x_ != 1) &&
1663-
mem_layout_ == mem_layout::col_major))>> {
1658+
std::enable_if_t<(!arch_has_2d_load_store<arch_tag_>)&&(
1659+
((block_size_y_ != 1 || tile_size_y_ != 1) &&
1660+
mem_layout_ == mem_layout::row_major) ||
1661+
((block_size_x_ != 1 || tile_size_x_ != 1) &&
1662+
mem_layout_ == mem_layout::col_major))>> {
16641663
using dtype = native_type_t<dtype_>;
16651664
using mem_desc_t =
16661665
mem_desc_t<dtype_, mem_layout_, mem_space::global, alignment_>;
@@ -1902,10 +1901,9 @@ struct prefetch_payload_t<
19021901
reg_layout_>,
19031902
num_coop_sg_,
19041903
arch_tag_,
1905-
std::enable_if_t<
1906-
(arch_has_2d_load_store<arch_tag_>) &&
1907-
(((tile_size_y_ != 1) && mem_layout_ == mem_layout::row_major) ||
1908-
((tile_size_x_ != 1) && mem_layout_ == mem_layout::col_major))>> {
1904+
std::enable_if_t<(arch_has_2d_load_store<arch_tag_>)&&(
1905+
((tile_size_y_ != 1) && mem_layout_ == mem_layout::row_major) ||
1906+
((tile_size_x_ != 1) && mem_layout_ == mem_layout::col_major))>> {
19091907
using dtype = dtype_;
19101908
using mem_desc_t =
19111909
mem_desc_t<dtype_, mem_layout_, mem_space::global, alignment_>;

tests/integration/fmha/fmha_forward.hpp

Lines changed: 31 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -129,12 +129,12 @@ class fmha_forward_t {
129129
using comp_attr = group::compute_attr_t<scalar_t, scalar_t, accum_t>;
130130
using knobs = group::perf_tuning_knob_t<accum_step, stages, sync_freq>;
131131
using compute_policy_BrBc = std::conditional_t<
132-
(arch_tag >= gpu_arch::XeHpg),
132+
(arch_has_xmx<arch_tag>),
133133
group::compute_policy_default_xmx<comp_attr, knobs, arch_tag>,
134134
group::compute_policy_default_fpu<comp_attr, knobs, arch_tag>>;
135135
// TODO: add k slicing
136136
using compute_policy_BrBm = std::conditional_t<
137-
(arch_tag >= gpu_arch::XeHpg),
137+
(arch_has_xmx<arch_tag>),
138138
group::compute_policy_default_xmx<comp_attr, knobs, arch_tag>,
139139
group::compute_policy_default_fpu<comp_attr, knobs, arch_tag>>;
140140
// ---------------- // Tile shape and Threads // ---------------- //
@@ -688,7 +688,7 @@ class fmha_forward_t {
688688
uint8_t,
689689
mem_desc_Dp_Mask_t::layout,
690690
mem_desc_Dp_Mask_t::space>>,
691-
gpu_arch::XeHpc>;
691+
arch_tag>;
692692
load_payload_mask_t load_payload_mask(ctx.mem_desc_Dpij);
693693
subgroup::tile_load(mask_in, load_payload_mask);
694694
matAccSij.reg = matAccSij.reg * mask_in.reg * args.dp_scale;
@@ -771,7 +771,7 @@ class fmha_forward_t {
771771
uint32_t height = args.uB * args.uN * args.uF;
772772
uint32_t offset_height = b * args.uN * args.uF + f * args.uN + n;
773773

774-
if constexpr (arch_tag != gpu_arch::XeHpc) {
774+
if constexpr (!arch_has_2d_load_store<arch_tag>) {
775775
// offset for curr work item
776776
const uint32_t O_offset = offset_height * args.uH + h;
777777
const auto ld_c = args.uN * args.uH;
@@ -798,30 +798,30 @@ class fmha_forward_t {
798798
matOi_store_t matOi_store(mem_desc_Oi);
799799
subgroup::tile_store<cache_hint::write_back, cache_hint::write_back>(
800800
matOi, matOi_store);
801-
return;
802-
}
803-
804-
xetla_fill_tdesc<scalar_t, kSgHm, 1, 1>(
805-
transpose_tdecs.xetla_format<uint32_t>(),
806-
args.O_ptr,
807-
args.uH,
808-
height,
809-
args.uH,
810-
h,
811-
offset_height);
812-
813-
for (uint32_t i = 0; i < kSgBr && (f + i < args.uF); ++i) {
814-
// load data from matAccOi
815-
auto v_acc = matAccOi.reg.xetla_select<kSgHm, 1>(i * kSgHm);
816-
v_out = xetla_cvt<scalar_t, accum_t, kSgHm>(v_acc);
817-
818-
xetla_tstore_global<
819-
scalar_t,
820-
kSgHm,
821-
cache_hint::write_back,
822-
cache_hint::write_back>(transpose_tdecs, v_out);
823-
xetla_update_tdesc_offsety(
824-
transpose_tdecs.xetla_format<uint32_t>(), args.uN);
801+
} else {
802+
xetla_fill_tdesc<scalar_t, kSgHm, 1, 1>(
803+
transpose_tdecs.xetla_format<uint32_t>(),
804+
args.O_ptr,
805+
args.uH,
806+
height,
807+
args.uH,
808+
h,
809+
offset_height);
810+
811+
for (uint32_t i = 0; i < kSgBr && (f + i < args.uF); ++i) {
812+
// load data from matAccOi
813+
auto v_acc = matAccOi.reg.xetla_select<kSgHm, 1>(i * kSgHm);
814+
v_out = xetla_cvt<scalar_t, accum_t, kSgHm>(v_acc);
815+
816+
xetla_tstore_global<
817+
scalar_t,
818+
kSgHm,
819+
cache_hint::write_back,
820+
cache_hint::write_back,
821+
arch_tag>(transpose_tdecs, v_out);
822+
xetla_update_tdesc_offsety(
823+
transpose_tdecs.xetla_format<uint32_t>(), args.uN);
824+
}
825825
}
826826
}
827827
// ====================== // preload_Qi // ====================== //
@@ -888,16 +888,9 @@ class fmha_forward_t {
888888
/// @return The size of local memory required.
889889
inline static constexpr uint32_t get_slm_size() {
890890
constexpr uint32_t size = slm_size_Qi + slm_size_Pij + slm_size_softmax;
891-
if constexpr (arch_tag == gpu_arch::XeHpc) {
892-
static_assert(
893-
size <= (128 * 1024),
894-
"The local memory size should be less than 128KB!");
895-
896-
} else {
897-
static_assert(
898-
size <= (64 * 1024),
899-
"The local memory size should be less than 64KB!");
900-
}
891+
static_assert(
892+
size <= (arch_attr_t<arch_tag>::local_mem_size),
893+
"The local memory size should be less than arch total local memory size");
901894
return size;
902895
};
903896

tests/integration/fmha/fmha_utils.h

Lines changed: 13 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -134,7 +134,7 @@ template <
134134
typename mat_t,
135135
uint32_t kNumSg,
136136
reduce_op reduce_kind,
137-
gpu_arch arch_tag = gpu_arch::XeHpc>
137+
gpu_arch arch_tag>
138138
struct group_row_reduce_t {
139139
using T = typename mat_t::dtype;
140140
static constexpr uint32_t kNum = mat_t::tile_desc::tile_size_y;
@@ -215,7 +215,7 @@ enum class add_type : uint8_t {
215215
/// @tparam arch_tag Is the hardware architecture tag.
216216
template <
217217
typename dtype_bias_,
218-
gpu_arch arch_tag = gpu_arch::XeHpc,
218+
gpu_arch arch_tag,
219219
add_type add_tag = add_type::single_line>
220220
struct bias_add_op_t {};
221221

@@ -324,8 +324,8 @@ struct bias_add_op_t<dtype_bias_, arch_tag, add_type::single_element> {
324324
using base_t = typename mem_desc_bias_t::base_t;
325325

326326
struct arguments_t {
327-
shape_t shape;
328327
base_t base;
328+
shape_t shape;
329329
inline arguments_t() = default;
330330
inline arguments_t(base_t base_, shape_t shape_)
331331
: base(base_), shape(shape_) {}
@@ -351,11 +351,10 @@ struct bias_add_op_t<dtype_bias_, arch_tag, add_type::single_element> {
351351
uint32_t offset = (pos_y + pos_x * args.shape.stride) * sizeof(dtype_bias);
352352
auto bias_data_vector = xetla_load_global<
353353
dtype_bias,
354+
16,
354355
1,
355-
data_size::default_size,
356-
cache_hint::cached,
357356
cache_hint::cached,
358-
16>(ptr, offset);
357+
cache_hint::cached>(ptr, offset);
359358
dtype_acc bias_data =
360359
xetla_cvt<dtype_acc, dtype_bias, 16>(bias_data_vector)[0];
361360

@@ -418,15 +417,19 @@ template <
418417
typename mem_desc_c_t_>
419418
class epilogue_transp_t {};
420419

421-
template <typename tile_op_t_, typename tile_shape_, typename mem_desc_c_t_>
420+
template <
421+
typename tile_op_t_,
422+
typename tile_shape_,
423+
typename mem_desc_c_t_,
424+
gpu_arch arch_tag_>
422425
class epilogue_transp_t<
423-
epilogue_policy_tile_op<tile_op_t_, gpu_arch::XeHpc>,
426+
epilogue_policy_tile_op<tile_op_t_, arch_tag_>,
424427
tile_shape_,
425428
mem_desc_c_t_> {
426429
public:
427430
using tile_shape = tile_shape_;
428431
using mem_desc_c_t = mem_desc_c_t_;
429-
static constexpr gpu_arch arch_tag = gpu_arch::XeHpc;
432+
static constexpr gpu_arch arch_tag = arch_tag_;
430433
static constexpr uint32_t barrier_count = 0;
431434
static constexpr uint32_t slm_size = 0;
432435

@@ -505,7 +508,7 @@ class epilogue_write_back_t<
505508
epilogue_policy_default<arch_tag_>,
506509
tile_shape_,
507510
mem_desc_c_t_,
508-
std::enable_if_t<((arch_tag_ <= gpu_arch::XeHpc))>> {
511+
std::enable_if_t<valid_xe_arch_tag<arch_tag_>>> {
509512
public:
510513
using epilogue_policy = epilogue_policy_default<arch_tag_>;
511514
using tile_shape = tile_shape_;

0 commit comments

Comments
 (0)