This repository was archived by the owner on Aug 30, 2024. It is now read-only.

Commit ebf5e71

DDEle and airMeng authored
DG2 supports for XeTLA SDP example (#159)
---------

Co-authored-by: Meng, Hengyu <hengyu.meng@intel.com>
1 parent 5b0327d commit ebf5e71


43 files changed: +633 −428 lines


.editorconfig

Lines changed: 15 additions & 0 deletions
@@ -0,0 +1,15 @@
+# EditorConfig is awesome: https://EditorConfig.org
+
+# top-most EditorConfig file
+root = true
+
+# Unix-style newlines with a newline ending every file
+[*]
+end_of_line = lf
+insert_final_newline = true
+trim_trailing_whitespace = true
+
+# C/C++ follows clang-format
+[*.{c,cpp,h,hpp}]
+indent_style = space
+indent_size = 4

CMakeLists.txt

Lines changed: 2 additions & 2 deletions
@@ -9,14 +9,14 @@ if (NOT CMAKE_BUILD_TYPE)
 endif()
 if(UNIX)
 else() # Windows
-# Force CMake to use icx-cl rather than the default C++ compiler/linker
+# Force CMake to use icx-cl rather than the default C++ compiler/linker
 # (needed on Windows only)
 # include (CMakeForceCompiler)
 # CMAKE_FORCE_CXX_COMPILER (icx-cl IntelDPCPP)
 set(CMAKE_CXX_COMPILER icx-cl)
 include (Platform/Windows-Clang)
 include(cmake/GTestExternal.cmake)
-endif()
+endif()
 
 project(XeTLA)
 

examples/01_gemm_universal/gemm_universal.cpp

Lines changed: 43 additions & 36 deletions
@@ -13,12 +13,13 @@
  * See the License for the specific language governing permissions and
  * limitations under the License.
  *******************************************************************************/
-#include <tests/utils/utils.hpp>
 #include "xetla.hpp"
+#include <tests/utils/utils.hpp>
 
 enum class kslicing_impl_t : uint8_t { none = 0, global = 1, local = 2 };
 
-template <kslicing_impl_t kslicing_type = kslicing_impl_t::none>
+template <gpu_arch arch_tag,
+        kslicing_impl_t kslicing_type = kslicing_impl_t::none>
 void gemm_universal_run(uint32_t iter) {
     // Tips, the example demonstrates programming kernel with XeTLA, it works as expected with current configurations.
     // Please make sure you fully understand these configurations before you do any modifications, incomplete changes may lead to unexpected behaviors.
@@ -82,7 +83,7 @@ void gemm_universal_run(uint32_t iter) {
     constexpr uint32_t num_local_splitk
             = (kslicing_type == kslicing_impl_t::local) ? 2 : 1;
 
-    // Mirco-kernel configuration
+    // Micro-kernel configuration
     using tune_option = dict_t<
             elem_v_t<tune_key::param_optimizer_type,
                     tune_key_value::param_optimizer_decision_tree>,
@@ -102,8 +103,8 @@ void gemm_universal_run(uint32_t iter) {
             data_type_c, // output datatype for C
             mem_layout::row_major, // memory layout for C
             8, // leading dimension alignment for C, in unit of element
-            data_type_acc, // accumulator data type for intermediate resutls
-            gpu_arch::Xe, // GPU arch
+            data_type_acc, // accumulator data type for intermediate results
+            arch_tag, // GPU arch
             tune_option>;
 
     // allocate temp buffers for global split
@@ -184,36 +185,42 @@ void gemm_universal_run(uint32_t iter) {
     free(Cnt, context);
 }
 
+template <gpu_arch arch_tag>
+struct main_wrapper {
+    static constexpr auto exec = []() {
+        // An example code for calculating matrix multiplication using
+        // GEMM_UNIVERSAL API:
+        // C = A x B
+        // The resulted matrix C is partitioned by the group range
+        // in to multiple blocks. The block matrix
+        // C<i_w, j_w>
+        // is computed by the workgroup with id: (0, i_w, j_w).
+        // (i_w, j_w) is an element in range specified by group range.
+        // Each thread with index (0, i_s, j_s) inside the same workgroup
+        // is responsible for a sub block of matrix multiplication, which is
+        // C<i_w, j_w>[i_s*sg_m:(i_s+1):sg_m,j_s*sg_n:(j_s+1)*sg_n]
+
+        // Alternatively, some threads can cooperate on the same sub block
+        // matrix given the same (i_s, j_s), i.e. the index space is extended
+        // from (0, i_s, j_s) to (k_s, i_s, j_s).
+
+        // Another method to achieve the same effect is to extend the index space
+        // in group range, i.e. from (0, i_w, j_w) to (k_w, i_w, j_w)
+
+        // More detailed description referring to the cooperation (kslicing) could
+        // be found in the example 01_gemm_universal with custom implementation
+
+        // basic gemm_universal
+        gemm_universal_run<arch_tag, kslicing_impl_t::none>(10);
+
+        // basic gemm_universal with workgroup cooperation
+        // gemm_universal_run<arch_tag, kslicing_impl_t::global>(10);
+
+        // basic gemm_universal with thread cooperation
+        // gemm_universal_run<arch_tag, kslicing_impl_t::local>(10);
+    };
+};
 int main() {
-    // An example code for calculating matrix multiplication using
-    // GEMM_UNIVERSAL API:
-    // C = A x B
-    // The resulted matrix C is partitioned by the group range
-    // in to multiple blocks. The block matrix
-    // C<i_w, j_w>
-    // is computed by the workgroup with id: (0, i_w, j_w).
-    // (i_w, j_w) is an element in range specified by group range.
-    // Each thread with index (0, i_s, j_s) inside the same workgroup
-    // is responsible for a sub block of matrix multiplication, which is
-    // C<i_w, j_w>[i_s*sg_m:(i_s+1):sg_m,j_s*sg_n:(j_s+1)*sg_n]
-
-    // Alternatively, some threads can cooperate on the same sub block
-    // matrix given the same (i_s, j_s), i.e. the index space is extended
-    // from (0, i_s, j_s) to (k_s, i_s, j_s).
-
-    // Another method to achieve the same effect is to extend the index space
-    // in group range, i.e. from (0, i_w, j_w) to (k_w, i_w, j_w)
-
-    // More detailed description referring to the cooperation (kslicing) could
-    // be found in the example 01_gemm_universal with custom implementation
-
-    // basic gemm_universal
-    gemm_universal_run<kslicing_impl_t::none>(10);
-
-    // basic gemm_universal with workgroup cooperation
-    // gemm_universal_run<kslicing_impl_t::global>(10);
-
-    // basic gemm_universal with thread cooperation
-    // gemm_universal_run<kslicing_impl_t::local>(10);
-    return (0);
+    dispatch_arch<main_wrapper>::exec();
+    return 0;
 }
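
For reference, the reworked template signature also lets a caller pin one architecture explicitly instead of going through the dispatcher. A minimal sketch, not taken from this diff; it only reuses gemm_universal_run, kslicing_impl_t, and the gpu_arch::Dg2 / gpu_arch::Xe enumerators that appear elsewhere in this commit:

// Hypothetical caller that targets DG2 (Arc) explicitly; swap in gpu_arch::Xe
// for PVC. Assumes "xetla.hpp" and the example's gemm_universal_run are in scope.
int main() {
    gemm_universal_run<gpu_arch::Dg2, kslicing_impl_t::none>(10);
    return 0;
}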

examples/02_basic_gemm/basic_gemm.cpp

Lines changed: 30 additions & 31 deletions
@@ -13,10 +13,10 @@
  * See the License for the specific language governing permissions and
  * limitations under the License.
  *******************************************************************************/
-#include <tests/utils/utils.hpp>
 #include "xetla.hpp"
+#include <tests/utils/utils.hpp>
 
-template <gpu_arch arch_tag_>
+template <gpu_arch arch_tag>
 void basic_gemm_run(sycl::queue queue, uint32_t iter) {
     // Tips, the example demonstrates programming kernel with XeTLA, it works as expected with current configurations.
     // Please make sure you fully understand these configurations before you do any modifications, incomplete changes may lead to unexpected behaviors.
@@ -110,11 +110,11 @@ void basic_gemm_run(sycl::queue queue, uint32_t iter) {
     // should larger than 8
     static constexpr uint32_t k_stride = 32;
 
-    // Step 1: define mirco-kernel's configuration
+    // Step 1: define Micro-kernel's configuration
     using wg_shape = shape<wg_tile_n, wg_tile_m>;
     using sg_shape = shape<sg_tile_n, sg_tile_m>;
 
-    // Mirco-kernel configuration
+    // Micro-kernel configuration
     using gemm_tune_option
             = dict_t<elem_t_t<tune_key::sg_tile_shape, sg_shape>,
                     elem_v_t<tune_key::prefetch_distance,
@@ -132,10 +132,10 @@ void basic_gemm_run(sycl::queue queue, uint32_t iter) {
             8, // leading dimension for B, in unit of element
             mem_space::
                     global, // memory reading from global mem for B
-            data_type_acc, // accumulator data type for intermediate resutls
+            data_type_acc, // accumulator data type for intermediate results
             wg_shape, // computation tile shape
             k_stride, // elements in each iteration
-            gpu_arch::Xe, // GPU arch
+            arch_tag, // GPU arch
             gemm_tune_option>;
     gemm_t gemm;
 
@@ -149,24 +149,26 @@ void basic_gemm_run(sycl::queue queue, uint32_t iter) {
             mem_space::global, // memory writing to global mem for C
             wg_shape, // computation tile shape
             k_stride, // elements in each iteration
-            gpu_arch::Xe, // GPU arch
+            arch_tag, // GPU arch
             epilogue_tune_option>;
 
     // Step 3: define the shared local memory usages
     // developers have the responsibility to set
-    // shared loacal memory through XeTLA API
+    // shared local memory through XeTLA API
     static constexpr uint32_t barrier_count = gemm_t::barrier_count;
     static constexpr uint32_t slm_size = gemm_t::slm_size;
+    static_assert(slm_size <= arch_attr_t<arch_tag>::local_mem_size,
+            "The local memory size excess!");
     xetla_nbarrier_init<barrier_count>();
     xetla_local_init<slm_size>();
 
-    // Step 4: ecah workgroup gets it individual index to start computation
+    // Step 4: each workgroup gets it individual index to start computation
     int start_n = item.get_group(2) * wg_tile_n;
     int start_m = item.get_group(1) * wg_tile_m;
     // no slicing in K direction so start from zero for all WG
     int start_k = 0;
 
-    // Each workgroup will compute all data in K based on no k_sliciing
+    // Each workgroup will compute all data in K based on no k_slicing
    // The developer can set how much data a subgroup compute by k_stride
     uint32_t wg_tile_k = matrix_k;
     uint32_t inner_loop_count
@@ -183,7 +185,7 @@ void basic_gemm_run(sycl::queue queue, uint32_t iter) {
     mem_desc_output_c md_c(
             {C}, {matrix_n, matrix_m, ldc}, {start_n, start_m});
 
-    // Step 6: real calculation with accumulator varibales which suppose
+    // Step 6: real calculation with accumulator variables which suppose
     // will be in register.
     typename gemm_t::matAcc_t matAcc;
     matAcc.init(0);
@@ -194,8 +196,7 @@ void basic_gemm_run(sycl::queue queue, uint32_t iter) {
     // the results is in the matAcc rather than real output C
     typename gemm_t::work_group_t g(item.get_local_linear_id());
     gemm(g, matAcc, gemm_args);
-
-    // Step 7: write the results from matACC to real output C
+    // Step 7: write the results from matAcc to real output C
     epilogue_t epilogue;
     epilogue(g, matAcc, md_c);
 });
@@ -220,23 +221,21 @@ void basic_gemm_run(sycl::queue queue, uint32_t iter) {
     free(C, context);
 }
 
+template <gpu_arch arch_tag>
+struct main_wrapper {
+    static constexpr auto exec = []() {
+        // This case shows how to use batch-reduce (br) GEMM microkernel to
+        // solve a standard GEMM
+        // Turn on the profiling property to facilitate subsequent profiling
+        sycl::property_list properties {
+                sycl::property::queue::enable_profiling()};
+
+        // Define SYCL queue, context and device
+        auto queue = sycl::queue(properties);
+        basic_gemm_run<arch_tag>(queue, 10);
+    };
+};
 int main() {
-    // This case shows how to use batch-reduce (br) GEMM microkernel to
-    // solve a standard GEMM
-    // Turn on the profiling property to facilitate subsequent profiling
-    sycl::property_list properties {sycl::property::queue::enable_profiling()};
-
-    // Define SYCL queue, context and device
-    auto queue = sycl::queue(properties);
-    auto device = queue.get_device();
-
-    // Detect the execution size, 8 for Arc, 16 for PVC.
-    int ExecSize
-            = device.get_info<ext::intel::info::device::gpu_eu_simd_width>();
-    if (ExecSize == 8) {
-        basic_gemm_run<gpu_arch::Dg2>(queue, 10);
-    } else {
-        basic_gemm_run<gpu_arch::Xe>(queue, 10);
-    }
-    return (0);
+    dispatch_arch<main_wrapper>::exec();
+    return 0;
 }
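
The dispatch_arch helper that the new main() relies on is not defined in the hunks shown here. A plausible sketch follows, assuming it simply performs the runtime detection this commit removes from main() (EU SIMD width 8 for Arc/DG2, 16 for PVC/Xe); everything except gpu_eu_simd_width, gpu_arch, and the wrapper's exec member is illustrative:

// Hypothetical reconstruction of dispatch_arch -- not taken from this commit.
// It routes the templated main_wrapper to DG2 or Xe based on the device's
// EU SIMD width, mirroring the removed detection logic in the diff above.
#include <sycl/sycl.hpp>
#include "xetla.hpp"

template <template <gpu::xetla::gpu_arch> typename wrapper_t>
struct dispatch_arch {
    static void exec() {
        sycl::queue queue;
        auto device = queue.get_device();
        // Detect the execution size: 8 for Arc (DG2), 16 for PVC.
        int exec_size = device.get_info<
                sycl::ext::intel::info::device::gpu_eu_simd_width>();
        if (exec_size == 8) {
            wrapper_t<gpu::xetla::gpu_arch::Dg2>::exec();
        } else {
            wrapper_t<gpu::xetla::gpu_arch::Xe>::exec();
        }
    }
};

If the in-tree dispatch_arch does something richer (for example, compile-time gating of unsupported targets), this sketch only captures the runtime-selection idea.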

examples/03_gemm_relu_bias/gemm_relu_bias.cpp

Lines changed: 5 additions & 5 deletions
@@ -14,8 +14,8 @@
  * limitations under the License.
  *******************************************************************************/
 #include <algorithm>
-#include <tests/utils/utils.hpp>
 #include "xetla.hpp"
+#include <tests/utils/utils.hpp>
 
 using namespace cl::sycl;
 using namespace gpu::xetla;
@@ -140,7 +140,7 @@ void gemm_relu_bias_run(uint32_t iter) {
     using epilogue_policy
             = xetla::group::epilogue_policy_tile_op<tile_op_t, gpu_arch::Xe>;
 
-    // Mirco-kernel configuration
+    // Micro-kernel configuration
     using tune_option = dict_t<
             elem_v_t<tune_key::param_optimizer_type,
                     tune_key_value::param_optimizer_decision_tree>,
@@ -156,7 +156,7 @@ void gemm_relu_bias_run(uint32_t iter) {
             data_type_c, // output datatype for C
             mem_layout::row_major, // memory layout for C
             8, // leading dimension alignment for C, in unit of element
-            data_type_acc, // accumulator data type for intermediate resutls
+            data_type_acc, // accumulator data type for intermediate results
             gpu_arch::Xe, // GPU arch
             tune_option>;
     using gemm_op_t = typename default_config_t::type;
@@ -223,15 +223,15 @@ int main() {
     // The purpose of this example is to illustrate the epilogue_t API in XeTLA.
 
     // It allows user to implement multiple Ops inside a kernel call to avoid
-    // overheads in invokation, memory transfer, etc.
+    // overheads in invocation, memory transfer, etc.
     // Take the following python code as an example:
 
     // Original:
     // > import torch as to
     // > x = to.matmul(A, B)
     // > y = to.nn.functional.relu(x)
 
-    // It takes two kernel invokations and the ReLU Op is a elementwise operation
+    // It takes two kernel invocations and the ReLU Op is a elementwise operation
     // that could be fused into MatMul Op, which is basically calling GEMM kernel.
 
     // Fusion:

examples/04_gemm_polynomial/gemm_polynomial.cpp

Lines changed: 3 additions & 3 deletions
@@ -14,8 +14,8 @@
  * limitations under the License.
  *******************************************************************************/
 #include <algorithm>
-#include <tests/utils/utils.hpp>
 #include "xetla.hpp"
+#include <tests/utils/utils.hpp>
 
 #include "gemm_polynomial.hpp"
 
@@ -137,7 +137,7 @@ void gemm_polynomial_run(int iter) {
     using epilogue_policy
             = xetla::group::epilogue_policy_tile_op<tile_op_t, gpu_arch::Xe>;
 
-    // Mirco-kernel configuration
+    // Micro-kernel configuration
     using tune_option = dict_t<
             elem_v_t<tune_key::param_optimizer_type,
                     tune_key_value::param_optimizer_decision_tree>,
@@ -154,7 +154,7 @@ void gemm_polynomial_run(int iter) {
             data_type_c, // output datatype for C
             mem_layout::row_major, // memory layout for C
             8, // leading dimension alignment for C, in unit of element
-            data_type_acc, // accumulator data type for intermediate resutls
+            data_type_acc, // accumulator data type for intermediate results
             gpu_arch::Xe, // GPU arch
             tune_option>;
 

examples/05_batch_gemm/batch_gemm.cpp

Lines changed: 2 additions & 2 deletions
@@ -90,7 +90,7 @@ void batch_gemm_run(uint32_t iter) {
     using wg_shape = shape<wg_tile_n, wg_tile_m>;
     using sg_shape = shape<sg_tile_n, sg_tile_m>;
 
-    // Mirco-kernel configuration
+    // Micro-kernel configuration
     using tune_option
             = dict_t<elem_v_t<tune_key::param_optimizer_type,
                     tune_key_value::param_optimizer_decision_tree>,
@@ -106,7 +106,7 @@ void batch_gemm_run(uint32_t iter) {
             mem_layout::row_major, // memory layout for B
             8, // leading dimension for B, in unit of element
             mem_space::global, // memory reading from global mem for B
-            data_type_acc, // accumulator data type for intermediate resutls
+            data_type_acc, // accumulator data type for intermediate results
             wg_shape, // computation tile shape
             wg_tile_k, // elements in each iteration
             gpu_arch::Xe, // GPU arch

examples/05_batch_gemm/batch_gemm.hpp

Lines changed: 2 additions & 2 deletions
@@ -173,8 +173,8 @@ class batch_gemm_t {
     /// @return The size of local memory required.
     __XETLA_API static constexpr uint32_t get_slm_size() {
         constexpr uint32_t size = gemm_t::slm_size + epilogue_t::slm_size;
-        static_assert(size <= (128 * 1024),
-                "The local memory size should be less than 128KB!");
+        static_assert(size <= arch_attr_t<arch_tag>::local_mem_size,
+                "The local memory size excess!");
         return size;
     };
 

0 commit comments
