Skip to content

Commit 8d16f1c

Browse files
authored
Call RECORD_FUNCTION not only for IPEX on XeTLA benchmarks (#3027)
The same way as for Triton benchmarks, let's see how it works. Ref: #2510 https://github.com/intel/intel-xpu-backend-for-triton/actions/runs/12379668850 (upstream profiler) Signed-off-by: Anatoly Myachev <anatoly.myachev@intel.com>
1 parent aa7a897 commit 8d16f1c

File tree

1 file changed

+1
-12
lines changed

1 file changed

+1
-12
lines changed

benchmarks/xetla_kernel/python_main.cpp

Lines changed: 1 addition & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
#ifdef USE_IPEX
1313
#include <ipex.h>
1414
#else
15+
#include <ATen/record_function.h>
1516
#include <c10/xpu/XPUStream.h>
1617
#endif
1718

@@ -43,9 +44,7 @@ at::Tensor softmax(const at::Tensor &input, const at::Tensor &output,
4344
const int64_t dim) {
4445
CHECK_INPUT(input);
4546
CHECK_INPUT(output);
46-
#ifdef USE_IPEX
4747
RECORD_FUNCTION("xetla softmax", {});
48-
#endif
4948

5049
auto queue = get_current_sycl_queue();
5150
auto evt = softmax_forward<T>(input.data_ptr(), output.data_ptr(), queue);
@@ -63,9 +62,7 @@ at::Tensor bf16_gemm(const at::Tensor &a, const at::Tensor &b,
6362
CHECK_INPUT(b);
6463
CHECK_INPUT(c);
6564
CHECK_INPUT(acc);
66-
#ifdef USE_IPEX
6765
RECORD_FUNCTION("xetla gemm", {});
68-
#endif
6966

7067
auto queue = get_current_sycl_queue();
7168
auto evt = gemm_run<T>(a.data_ptr(), b.data_ptr(), c.data_ptr(),
@@ -83,9 +80,7 @@ at::Tensor bf16_stream_k_gemm(const at::Tensor &a, const at::Tensor &b,
8380
CHECK_INPUT(b);
8481
CHECK_INPUT(c);
8582
CHECK_INPUT(acc);
86-
#ifdef USE_IPEX
8783
RECORD_FUNCTION("xetla stream_k_gemm", {});
88-
#endif
8984

9085
auto queue = get_current_sycl_queue();
9186
auto evt = stream_k_gemm_run(a.data_ptr(), b.data_ptr(), c.data_ptr(),
@@ -105,9 +100,7 @@ at::Tensor bf16_split_k_gemm(const at::Tensor &a, const at::Tensor &b,
105100
CHECK_INPUT(b);
106101
CHECK_INPUT(c);
107102
CHECK_INPUT(acc);
108-
#ifdef USE_IPEX
109103
RECORD_FUNCTION("xetla split_k_gemm", {});
110-
#endif
111104

112105
auto queue = get_current_sycl_queue();
113106
auto evt = split_k_gemm_run<m, k, n, kslicing_type>(
@@ -143,9 +136,7 @@ void flash_attn(const at::Tensor &q, const at::Tensor &k, const at::Tensor &v,
143136
CHECK_INPUT(bias);
144137
CHECK_INPUT(m);
145138
CHECK_INPUT(l);
146-
#ifdef USE_IPEX
147139
RECORD_FUNCTION("xetla fa", {});
148-
#endif
149140

150141
auto queue = get_current_sycl_queue();
151142

@@ -212,9 +203,7 @@ void flash_attn_bwd(const at::Tensor &grad_out, const at::Tensor &q,
212203
CHECK_INPUT(grad_value);
213204
CHECK_INPUT(grad_bias);
214205

215-
#ifdef USE_IPEX
216206
RECORD_FUNCTION("xetla fa", {});
217-
#endif
218207

219208
auto queue = get_current_sycl_queue();
220209

0 commit comments

Comments
 (0)