 * See the License for the specific language governing permissions and
 * limitations under the License.
 *******************************************************************************/
-#include <tests/utils/utils.hpp>
 #include "xetla.hpp"
+#include <tests/utils/utils.hpp>
 
-template <gpu_arch arch_tag_>
+template <gpu_arch arch_tag>
 void basic_gemm_run(sycl::queue queue, uint32_t iter) {
     // Tips: the example demonstrates programming a kernel with XeTLA; it works as expected with the current configurations.
     // Please make sure you fully understand these configurations before you make any modifications; incomplete changes may lead to unexpected behaviors.
@@ -110,11 +110,11 @@ void basic_gemm_run(sycl::queue queue, uint32_t iter) {
             // should be larger than 8
             static constexpr uint32_t k_stride = 32;
 
-            // Step 1: define mirco-kernel's configuration
+            // Step 1: define Micro-kernel's configuration
             using wg_shape = shape<wg_tile_n, wg_tile_m>;
             using sg_shape = shape<sg_tile_n, sg_tile_m>;
 
-            // Mirco-kernel configuration
+            // Micro-kernel configuration
             using gemm_tune_option
                     = dict_t<elem_t_t<tune_key::sg_tile_shape, sg_shape>,
                             elem_v_t<tune_key::prefetch_distance,
@@ -132,10 +132,10 @@ void basic_gemm_run(sycl::queue queue, uint32_t iter) {
                     8, // leading dimension for B, in unit of element
                     mem_space::
                             global, // memory reading from global mem for B
-                    data_type_acc, // accumulator data type for intermediate resutls
+                    data_type_acc, // accumulator data type for intermediate results
                     wg_shape, // computation tile shape
                     k_stride, // elements in each iteration
-                    gpu_arch::Xe, // GPU arch
+                    arch_tag, // GPU arch
                     gemm_tune_option>;
             gemm_t gemm;
 
@@ -149,24 +149,26 @@ void basic_gemm_run(sycl::queue queue, uint32_t iter) {
                     mem_space::global, // memory writing to global mem for C
                     wg_shape, // computation tile shape
                     k_stride, // elements in each iteration
-                    gpu_arch::Xe, // GPU arch
+                    arch_tag, // GPU arch
                     epilogue_tune_option>;
 
             // Step 3: define the shared local memory usages
             // developers have the responsibility to set
-            // shared loacal memory through XeTLA API
+            // shared local memory through XeTLA API
             static constexpr uint32_t barrier_count = gemm_t::barrier_count;
             static constexpr uint32_t slm_size = gemm_t::slm_size;
+            static_assert(slm_size <= arch_attr_t<arch_tag>::local_mem_size,
+                    "Local memory size exceeded!");
             xetla_nbarrier_init<barrier_count>();
             xetla_local_init<slm_size>();
 
-            // Step 4: ecah workgroup gets it individual index to start computation
+            // Step 4: each workgroup gets its individual index to start computation
             int start_n = item.get_group(2) * wg_tile_n;
             int start_m = item.get_group(1) * wg_tile_m;
             // no slicing in K direction so start from zero for all WG
             int start_k = 0;
 
-            // Each workgroup will compute all data in K based on no k_sliciing
+            // Each workgroup will compute all data in K since there is no k_slicing
             // The developer can set how much data a subgroup computes by k_stride
             uint32_t wg_tile_k = matrix_k;
             uint32_t inner_loop_count
@@ -183,7 +185,7 @@ void basic_gemm_run(sycl::queue queue, uint32_t iter) {
             mem_desc_output_c md_c(
                     {C}, {matrix_n, matrix_m, ldc}, {start_n, start_m});
 
-            // Step 6: real calculation with accumulator varibales which suppose
+            // Step 6: real calculation with accumulator variables which supposedly
             // will be in register.
             typename gemm_t::matAcc_t matAcc;
             matAcc.init(0);
@@ -194,8 +196,7 @@ void basic_gemm_run(sycl::queue queue, uint32_t iter) {
             // the results are in matAcc rather than the real output C
             typename gemm_t::work_group_t g(item.get_local_linear_id());
             gemm(g, matAcc, gemm_args);
-
-            // Step 7: write the results from matACC to real output C
+            // Step 7: write the results from matAcc to the real output C
             epilogue_t epilogue;
             epilogue(g, matAcc, md_c);
         });
@@ -220,23 +221,21 @@ void basic_gemm_run(sycl::queue queue, uint32_t iter) {
     free(C, context);
 }
 
+template <gpu_arch arch_tag>
+struct main_wrapper {
+    static constexpr auto exec = []() {
+        // This case shows how to use batch-reduce (br) GEMM microkernel to
+        // solve a standard GEMM
+        // Turn on the profiling property to facilitate subsequent profiling
+        sycl::property_list properties {
+                sycl::property::queue::enable_profiling()};
+
+        // Define the SYCL queue
+        auto queue = sycl::queue(properties);
+        basic_gemm_run<arch_tag>(queue, 10);
+    };
+};
 int main() {
-    // This case shows how to use batch-reduce (br) GEMM microkernel to
-    // solve a standard GEMM
-    // Turn on the profiling property to facilitate subsequent profiling
-    sycl::property_list properties {sycl::property::queue::enable_profiling()};
-
-    // Define SYCL queue, context and device
-    auto queue = sycl::queue(properties);
-    auto device = queue.get_device();
-
-    // Detect the execution size, 8 for Arc, 16 for PVC.
-    int ExecSize
-            = device.get_info<ext::intel::info::device::gpu_eu_simd_width>();
-    if (ExecSize == 8) {
-        basic_gemm_run<gpu_arch::Dg2>(queue, 10);
-    } else {
-        basic_gemm_run<gpu_arch::Xe>(queue, 10);
-    }
-    return (0);
+    dispatch_arch<main_wrapper>::exec();
+    return 0;
 }
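Note: `dispatch_arch` comes from the shared test utilities (`tests/utils/utils.hpp`); its definition is not part of this diff. As a rough sketch only — assuming it picks the `gpu_arch` the same way the removed `main()` did, by querying the EU SIMD width (8 on Arc/DG2, 16 on PVC) — it might look like:

```cpp
// Hypothetical sketch, not the actual utility from tests/utils/utils.hpp:
// dispatches a wrapper templated on gpu_arch based on the device's
// EU SIMD width, mirroring the runtime probe that the removed main() used.
#include <sycl/sycl.hpp>
#include "xetla.hpp" // for gpu_arch

template <template <gpu_arch> class wrapper_t>
struct dispatch_arch {
    static void exec() {
        // Probe the default device, as the removed main() did.
        sycl::queue probe;
        auto device = probe.get_device();
        int exec_size = device.get_info<
                sycl::ext::intel::info::device::gpu_eu_simd_width>();
        if (exec_size == 8) {
            wrapper_t<gpu_arch::Dg2>::exec(); // Arc / DG2
        } else {
            wrapper_t<gpu_arch::Xe>::exec(); // PVC
        }
    }
};
```

Centralizing the probe this way means each example only supplies a wrapper templated on `gpu_arch` (here, `main_wrapper`), instead of repeating the SIMD-width detection in every `main()`.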