@@ -27,11 +27,14 @@ namespace gpu::xetla {
2727// / @{
2828
2929template <msg_type message_type, gpu_arch arch_tag>
30- struct load_store_attr_t {};
30+ struct load_store_attr_t {
31+ static constexpr bool has_hw_block_2d = false ;
32+ };
3133
3234template <>
3335struct load_store_attr_t <msg_type::block_2d, gpu_arch::XeHpc> {
3436 // / HW limitation checks https://gfxspecs.intel.com/Predator/Home/Index/55490
37+ static constexpr bool has_hw_block_2d = true ;
3538 static constexpr uint32_t max_load_height_in_elem = 32 ;
3639 static constexpr uint32_t max_load_width_in_bytes = 64 ;
3740 static constexpr uint32_t max_trans_load_width_in_bytes = 32 ;
@@ -53,6 +56,7 @@ struct load_store_attr_t<msg_type::block_2d, gpu_arch::XeHpc> {
5356template <msg_type message_type, gpu_arch arg_tag>
5457struct client_load_store_attr_base_t {
5558 // / HW limitation checks https://gfxspecs.intel.com/Predator/Home/Index/55490
59+ static constexpr bool has_hw_block_2d = false ;
5660 static constexpr uint32_t max_load_height_in_elem = 32 ;
5761 static constexpr uint32_t max_load_width_in_bytes = 64 ;
5862 static constexpr uint32_t max_trans_load_width_in_bytes = 32 ;
@@ -83,74 +87,116 @@ struct load_store_attr_t<msg_type::block_2d, gpu_arch::XeLpg>
8387 msg_type::block_2d,
8488 gpu_arch::XeLpg> {};
8589
90+ template <gpu_arch arch_tag>
91+ inline constexpr bool arch_has_2d_load_store =
92+ load_store_attr_t <msg_type::block_2d, arch_tag>::has_hw_block_2d;
93+
8694template <gpu_arch arch_tag>
8795struct load_store_attr_t <msg_type::block_1d, arch_tag> {
96+ static constexpr uint32_t max_load_vec_len = 32 ;
97+ static constexpr uint32_t max_store_vec_len = 32 ;
98+ static constexpr uint32_t max_prefetch_vec_len = 32 ;
99+ };
100+
101+ template <>
102+ struct load_store_attr_t <msg_type::block_1d, gpu_arch::XeHpc> {
88103 static constexpr uint32_t max_load_vec_len = 64 ;
89104 static constexpr uint32_t max_store_vec_len = 64 ;
105+ static constexpr uint32_t max_prefetch_vec_len = 64 ;
90106};
91107
92- template <gpu_arch arch_tag>
93- struct mma_attr_t {};
108+ struct dpas_attr_base_t {
109+ static constexpr bool has_xmx = true ;
110+ static constexpr uint32_t systolic_depth = 8 ;
111+ static constexpr uint32_t rcount_max = 8 ;
112+ static constexpr uint32_t op_per_channel_bits = 32 ;
113+ static constexpr uint32_t op_per_channel_bytes = (op_per_channel_bits >> 3 );
114+ static constexpr uint32_t op_per_channel_max = 8 ;
115+ };
94116
95117template <gpu_arch arch_tag>
96- struct client_mma_atr_base_t {
97- static constexpr uint32_t mma_m_in_elem = 8 ;
98- static constexpr uint32_t mma_n_in_elem = 8 ;
99- static constexpr uint32_t mma_k_in_bytes = 32 ;
118+ struct dpas_attr_t {
119+ static constexpr bool has_xmx = false ;
100120};
101121
102122template <>
103- struct mma_attr_t <gpu_arch::XeHpc> {
104- static constexpr uint32_t mma_m_in_elem = 8 ;
105- static constexpr uint32_t mma_n_in_elem = 16 ;
106- static constexpr uint32_t mma_k_in_bytes = 32 ;
123+ struct dpas_attr_t <gpu_arch::XeHpc> : public dpas_attr_base_t {
124+ static constexpr uint32_t n_fixed_limit = 16 ;
107125};
108126
109127template <>
110- struct mma_attr_t <gpu_arch::XeHpg>
111- : public client_mma_atr_base_t <gpu_arch::XeHpg> {};
128+ struct dpas_attr_t <gpu_arch::XeHpg> : public dpas_attr_base_t {
129+ static constexpr uint32_t n_fixed_limit = 8 ;
130+ };
112131
113- template <grf_mode grf_num_mode, gpu_arch arch_tag>
114- struct register_attr_t {} ;
132+ template <gpu_arch arch_tag>
133+ inline constexpr bool arch_has_xmx = dpas_attr_t <arch_tag>::has_xmx ;
115134
116- template <grf_mode grf_num_mode, gpu_arch arch_tag>
117- struct client_register_attr_base_t {
118- static constexpr uint32_t acc_reg_in_bytes =
119- (grf_num_mode == grf_mode::normal) ? 4 * 64 : 8 * 64 ;
120- static constexpr uint32_t grf_in_bytes =
121- (grf_num_mode == grf_mode::normal) ? 128 * 64 : 256 * 64 ;
122- static constexpr uint32_t reg_in_bytes = 64 ;
135+ template <gpu_arch arch_tag>
136+ struct fpu_attr_t {
137+ static constexpr bool has_fpu = true ;
123138};
124139
140+ template <gpu_arch arch_tag>
141+ inline constexpr bool arch_has_fpu = fpu_attr_t <arch_tag>::has_fpu;
142+
125143template <grf_mode grf_num_mode>
126- struct register_attr_t <grf_num_mode, gpu_arch::XeHpc> {
127- static constexpr uint32_t acc_reg_in_bytes =
128- (grf_num_mode == grf_mode::normal) ? 4 * 64 : 8 * 64 ;
129- static constexpr uint32_t grf_in_bytes =
130- (grf_num_mode == grf_mode::normal) ? 128 * 64 : 256 * 64 ;
144+ struct register_nums_t {
145+ static constexpr uint32_t register_nums =
146+ (grf_num_mode == grf_mode::normal) ? 128 : 256 ;
147+ static constexpr uint32_t acc_register_nums =
148+ (grf_num_mode == grf_mode::normal) ? 4 : 8 ;
149+ };
150+
151+ template <gpu_arch arch_tag>
152+ struct register_bytes_t {
131153 static constexpr uint32_t reg_in_bytes = 64 ;
132154};
133155
134- template <grf_mode grf_num_mode>
135- struct register_attr_t <grf_num_mode, gpu_arch::XeHpg>
136- : public client_register_attr_base_t <grf_num_mode, gpu_arch::XeHpg> {};
156+ template <grf_mode grf_num_mode, gpu_arch arch_tag>
157+ struct register_attr_t {
158+ static constexpr uint32_t reg_in_bytes =
159+ register_bytes_t <arch_tag>::reg_in_bytes;
160+ static constexpr uint32_t register_nums =
161+ register_nums_t <grf_num_mode>::register_nums;
162+ static constexpr uint32_t acc_register_nums =
163+ register_nums_t <grf_num_mode>::acc_register_nums;
164+ static constexpr uint32_t acc_reg_in_bytes = acc_register_nums * reg_in_bytes;
165+ static constexpr uint32_t grf_in_bytes = register_nums * reg_in_bytes;
166+ };
137167
138- template <grf_mode grf_num_mode>
139- struct register_attr_t <grf_num_mode, gpu_arch::XeLpg>
140- : public client_register_attr_base_t <grf_num_mode, gpu_arch::XeLpg> {};
168+ template <gpu_arch arch_tag, uint32_t m, class enable = void >
169+ struct mma_attr_t {};
170+
171+ template <gpu_arch arch_tag, uint32_t m>
172+ struct mma_attr_t <arch_tag, m, std::enable_if_t <arch_has_xmx<arch_tag>>> {
173+ using dpas_attr = dpas_attr_t <arch_tag>;
174+ static constexpr uint32_t mma_m_in_elem =
175+ (m > dpas_attr::rcount_max) ? dpas_attr::rcount_max : m;
176+ static constexpr uint32_t mma_n_in_elem = dpas_attr::n_fixed_limit;
177+ static constexpr uint32_t mma_k_in_bytes =
178+ dpas_attr::systolic_depth * dpas_attr::op_per_channel_bytes;
179+ };
180+
181+ template <gpu_arch arch_tag, uint32_t m>
182+ struct mma_attr_t <arch_tag, m, std::enable_if_t <!arch_has_xmx<arch_tag>>> {
183+ static constexpr uint32_t mma_m_in_elem = (m > 8 ) ? 8 : m;
184+ static constexpr uint32_t mma_n_in_elem = 16 ;
185+ static constexpr uint32_t mma_k_in_bytes = 32 ;
186+ };
141187
142188template <gpu_arch arch_tag>
143189struct arch_attr_t {};
144190
145191template <gpu_arch arch_tag>
146192struct client_arch_attr_base_t {
147193 template <msg_type message_type = msg_type::block_2d>
148- using load_store_attr = load_store_attr_t <message_type, gpu_arch::XeHpg >;
194+ using load_store_attr = load_store_attr_t <message_type, arch_tag >;
149195
150- template <grf_mode grf_num_mode = grf_mode::double_grf >
151- using register_attr = register_attr_t <grf_num_mode, gpu_arch::XeHpg >;
196+ template <grf_mode grf_num_mode = grf_mode::normal >
197+ using register_attr = register_attr_t <grf_num_mode, arch_tag >;
152198
153- using mma_attr = mma_attr_t <gpu_arch::XeHpg >;
199+ using dpas_attr = dpas_attr_t <arch_tag >;
154200
155201 static constexpr uint32_t max_wg_num = 64 ;
156202 static constexpr uint32_t local_mem_size = 64 * 1024 ;
@@ -164,7 +210,7 @@ struct arch_attr_t<gpu_arch::XeHpc> {
164210 template <grf_mode grf_num_mode = grf_mode::double_grf>
165211 using register_attr = register_attr_t <grf_num_mode, gpu_arch::XeHpc>;
166212
167- using mma_attr = mma_attr_t <gpu_arch::XeHpc>;
213+ using dpas_attr = dpas_attr_t <gpu_arch::XeHpc>;
168214
169215 static constexpr uint32_t max_wg_num = 64 ;
170216 static constexpr uint32_t local_mem_size = 128 * 1024 ;
0 commit comments