@@ -620,8 +620,12 @@ class fmha_forward_t {
620620 mem_desc_Dp_Mask_t::layout,
621621 mem_desc_Dp_Mask_t::space>,
622622 dp_mask_tile_desc_t ,
623- subgroup::
624- msg_type_v<dp_mask_tile_desc_t , mem_desc_Dp_Mask_t::space>,
623+ subgroup::msg_type_v<
624+ dp_mask_tile_desc_t ,
625+ mem_desc_t <
626+ uint8_t ,
627+ mem_desc_Dp_Mask_t::layout,
628+ mem_desc_Dp_Mask_t::space>>,
625629 gpu_arch::XeHpc>;
626630 load_payload_mask_t load_payload_mask (ctx.mem_desc_Dpij );
627631 subgroup::tile_load (mask_in, load_payload_mask);
@@ -722,7 +726,12 @@ class fmha_forward_t {
722726 using matOi_store_t = subgroup::mem_payload_t <
723727 mem_desc_t <scalar_t , mem_desc_Oi_t::layout, mem_desc_Oi_t::space>,
724728 matOi_tile_desc_t,
725- subgroup::msg_type_v<matOi_tile_desc_t, mem_desc_Oi_t::space>,
729+ subgroup::msg_type_v<
730+ matOi_tile_desc_t,
731+ mem_desc_t <
732+ scalar_t ,
733+ mem_desc_Oi_t::layout,
734+ mem_desc_Oi_t::space>>,
726735 arch_tag>;
727736 matOi_store_t matOi_store (mem_desc_Oi);
728737 subgroup::tile_store<cache_hint::write_back, cache_hint::write_back>(
@@ -762,12 +771,19 @@ class fmha_forward_t {
762771 using matQi_load_t = subgroup::mem_payload_t <
763772 mem_desc_t <scalar_t , mem_desc_Qi_t::layout, mem_desc_Qi_t::space>,
764773 matQi_tile_desc_t,
765- subgroup::msg_type_v<matQi_tile_desc_t, mem_desc_Qi_t::space>,
774+ subgroup::msg_type_v<
775+ matQi_tile_desc_t,
776+ mem_desc_t <scalar_t , mem_desc_Qi_t::layout, mem_desc_Qi_t::space>>,
766777 arch_tag>;
767778 using matQi_store_t = subgroup::mem_payload_t <
768779 mem_desc_t <scalar_t , mem_desc_Qi_L_t::layout, mem_desc_Qi_L_t::space>,
769780 matQi_tile_desc_t,
770- subgroup::msg_type_v<matQi_tile_desc_t, mem_desc_Qi_L_t::space>,
781+ subgroup::msg_type_v<
782+ matQi_tile_desc_t,
783+ mem_desc_t <
784+ scalar_t ,
785+ mem_desc_Qi_L_t::layout,
786+ mem_desc_Qi_L_t::space>>,
771787 arch_tag>;
772788
773789 int32_t tile_offset_x = ctx.sg_idx * kSgHm ;
0 commit comments