@@ -50,7 +50,7 @@ struct load_store_attr_t<msg_type::block_2d, gpu_arch::XeHpc> {
5050 static constexpr uint32_t special_prefetch_width_in_bytes = 64 ;
5151
5252 static constexpr uint32_t cache_line_size_in_bytes = 64 ;
53- static constexpr uint32_t alignment_in_bytes = 8 ;
53+ static constexpr uint32_t alignment_in_bytes = 16 ;
5454};
5555
5656template <msg_type message_type, gpu_arch arg_tag>
@@ -72,7 +72,7 @@ struct client_load_store_attr_base_t {
7272 static constexpr uint32_t special_prefetch_width_in_bytes = 64 ;
7373
7474 static constexpr uint32_t cache_line_size_in_bytes = 64 ;
75- static constexpr uint32_t alignment_in_bytes = 8 ;
75+ static constexpr uint32_t alignment_in_bytes = 4 ;
7676};
7777
7878template <>
@@ -94,15 +94,21 @@ inline constexpr bool arch_has_2d_load_store =
9494template <gpu_arch arch_tag>
9595struct load_store_attr_t <msg_type::block_1d, arch_tag> {
9696 static constexpr uint32_t max_load_vec_len = 256 ;
97+ static constexpr uint32_t max_aligned_load_vec_len = 256 ;
9798 static constexpr uint32_t max_store_vec_len = 256 ;
99+ static constexpr uint32_t max_aligned_store_vec_len = 256 ;
98100 static constexpr uint32_t max_prefetch_vec_len = 32 ;
101+ static constexpr uint32_t max_channel_num = 16 ;
99102};
100103
101104template <>
102105struct load_store_attr_t <msg_type::block_1d, gpu_arch::XeHpc> {
103- static constexpr uint32_t max_load_vec_len = 512 ;
104- static constexpr uint32_t max_store_vec_len = 512 ;
106+ static constexpr uint32_t max_load_vec_len = 256 ;
107+ static constexpr uint32_t max_aligned_load_vec_len = 512 ;
108+ static constexpr uint32_t max_store_vec_len = 256 ;
109+ static constexpr uint32_t max_aligned_store_vec_len = 512 ;
105110 static constexpr uint32_t max_prefetch_vec_len = 64 ;
111+ static constexpr uint32_t max_channel_num = 32 ;
106112};
107113
108114struct dpas_attr_base_t {
@@ -112,6 +118,7 @@ struct dpas_attr_base_t {
112118 static constexpr uint32_t op_per_channel_bits = 32 ;
113119 static constexpr uint32_t op_per_channel_bytes = (op_per_channel_bits >> 3 );
114120 static constexpr uint32_t op_per_channel_max = 8 ;
121+ static constexpr uint32_t k_in_bytes = systolic_depth * op_per_channel_bytes;
115122};
116123
117124template <gpu_arch arch_tag>
@@ -121,12 +128,12 @@ struct dpas_attr_t {
121128
122129template <>
123130struct dpas_attr_t <gpu_arch::XeHpc> : public dpas_attr_base_t {
124- static constexpr uint32_t n_fixed_limit = 16 ;
131+ static constexpr uint32_t n_in_elem = 16 ;
125132};
126133
127134template <>
128135struct dpas_attr_t <gpu_arch::XeHpg> : public dpas_attr_base_t {
129- static constexpr uint32_t n_fixed_limit = 8 ;
136+ static constexpr uint32_t n_in_elem = 8 ;
130137};
131138
132139template <gpu_arch arch_tag>
@@ -140,9 +147,10 @@ struct fpu_attr_t {
140147template <gpu_arch arch_tag>
141148inline constexpr bool arch_has_fpu = fpu_attr_t <arch_tag>::has_fpu;
142149
143- #define GRF grf_mode::double_grf
144150#ifdef NORMAL_GRF
145151#define GRF grf_mode::normal_grf
152+ #else
153+ #define GRF grf_mode::double_grf
146154#endif
147155
148156template <grf_mode grf_num_mode>
@@ -155,6 +163,7 @@ struct register_nums_t {
155163
156164template <gpu_arch arch_tag>
157165struct register_bytes_t ;
166+
158167template <>
159168struct register_bytes_t <gpu_arch::XeHpc> {
160169 static constexpr uint32_t reg_in_bytes = 64 ;
@@ -180,24 +189,49 @@ struct register_attr_t {
180189 static constexpr uint32_t grf_in_bytes = register_nums * reg_in_bytes;
181190};
182191
183- template <gpu_arch arch_tag, uint32_t m, class enable = void >
192+ template <
193+ gpu_arch arch_tag,
194+ mma_engine engine_type,
195+ uint32_t m,
196+ class enable = void >
184197struct mma_attr_t {};
185198
186199template <gpu_arch arch_tag, uint32_t m>
187- struct mma_attr_t <arch_tag, m, std::enable_if_t <arch_has_xmx<arch_tag>>> {
200+ struct mma_attr_t <
201+ arch_tag,
202+ mma_engine::xmx,
203+ m,
204+ std::enable_if_t <arch_has_xmx<arch_tag>>> {
188205 using dpas_attr = dpas_attr_t <arch_tag>;
206+ using load_store_attr = load_store_attr_t <msg_type::block_2d, arch_tag>;
189207 static constexpr uint32_t mma_m_in_elem =
190208 (m > dpas_attr::rcount_max) ? dpas_attr::rcount_max : m;
191- static constexpr uint32_t mma_n_in_elem = dpas_attr::n_fixed_limit;
192- static constexpr uint32_t mma_k_in_bytes =
193- dpas_attr::systolic_depth * dpas_attr::op_per_channel_bytes;
209+ static constexpr uint32_t blk_m_in_elem = 16 ;
210+
211+ static constexpr uint32_t mma_n_in_elem = dpas_attr::n_in_elem;
212+ [[maybe_unused]] static constexpr uint32_t blk_n_in_bytes =
213+ load_store_attr::max_trans_load_width_in_bytes;
214+
215+ static constexpr uint32_t mma_k_in_bytes = dpas_attr::k_in_bytes;
216+ static constexpr uint32_t blk_k_in_bytes = mma_k_in_bytes;
194217};
195218
196219template <gpu_arch arch_tag, uint32_t m>
197- struct mma_attr_t <arch_tag, m, std::enable_if_t <!arch_has_xmx<arch_tag>>> {
220+ struct mma_attr_t <
221+ arch_tag,
222+ mma_engine::fpu,
223+ m,
224+ std::enable_if_t <arch_has_fpu<arch_tag>>> {
225+ using load_store_attr = load_store_attr_t <msg_type::block_2d, arch_tag>;
198226 static constexpr uint32_t mma_m_in_elem = (m > 8 ) ? 8 : m;
199- static constexpr uint32_t mma_n_in_elem = 16 ;
227+ static constexpr uint32_t blk_m_in_elem = 16 ;
228+
200229 static constexpr uint32_t mma_k_in_bytes = 32 ;
230+ static constexpr uint32_t blk_k_in_bytes = mma_k_in_bytes;
231+
232+ [[maybe_unused]] static constexpr uint32_t mma_n_in_elem = 16 ;
233+ static constexpr uint32_t blk_n_in_bytes =
234+ register_bytes_t <arch_tag>::reg_in_bytes;
201235};
202236
203237template <gpu_arch arch_tag>
@@ -208,43 +242,51 @@ struct arch_attr_t<gpu_arch::XeHpc> {
208242 template <msg_type message_type = msg_type::block_2d>
209243 using load_store_attr = load_store_attr_t <message_type, gpu_arch::XeHpc>;
210244
211- template <grf_mode grf_num_mode = grf_mode::double_grf >
245+ template <grf_mode grf_num_mode = GRF >
212246 using register_attr = register_attr_t <grf_num_mode, gpu_arch::XeHpc>;
213247
214248 using dpas_attr = dpas_attr_t <gpu_arch::XeHpc>;
215249
216250 static constexpr uint32_t max_wg_num = 64 ;
217251 static constexpr uint32_t local_mem_size = 128 * 1024 ;
252+ static constexpr bool has_named_barrier = true ;
218253};
219254
220255template <>
221256struct arch_attr_t <gpu_arch::XeHpg> {
222257 template <msg_type message_type = msg_type::block_2d>
223258 using load_store_attr = load_store_attr_t <message_type, gpu_arch::XeHpg>;
224259
225- template <grf_mode grf_num_mode = grf_mode::double_grf >
260+ template <grf_mode grf_num_mode = GRF >
226261 using register_attr = register_attr_t <grf_num_mode, gpu_arch::XeHpg>;
227262
228263 using dpas_attr = dpas_attr_t <gpu_arch::XeHpg>;
229264
230- static constexpr uint32_t max_wg_num = 64 ;
265+ static constexpr uint32_t max_wg_num = 32 ;
231266 static constexpr uint32_t local_mem_size = 64 * 1024 ;
267+
268+ static constexpr bool has_named_barrier = false ;
232269};
233270
234271template <>
235272struct arch_attr_t <gpu_arch::XeLpg> {
236273 template <msg_type message_type = msg_type::block_2d>
237274 using load_store_attr = load_store_attr_t <message_type, gpu_arch::XeLpg>;
238275
239- template <grf_mode grf_num_mode = grf_mode::double_grf >
276+ template <grf_mode grf_num_mode = GRF >
240277 using register_attr = register_attr_t <grf_num_mode, gpu_arch::XeLpg>;
241278
242279 using dpas_attr = dpas_attr_t <gpu_arch::XeLpg>;
243280
244- static constexpr uint32_t max_wg_num = 64 ;
281+ static constexpr uint32_t max_wg_num = 32 ;
245282 static constexpr uint32_t local_mem_size = 64 * 1024 ;
283+ static constexpr bool has_named_barrier = false ;
246284};
247285
286+ template <gpu_arch arch_tag>
287+ inline constexpr bool arch_has_named_barrier =
288+ arch_attr_t <arch_tag>::has_named_barrier;
289+
248290// / @} xetla_core_arch_config
249291
250292} // namespace gpu::xetla
0 commit comments