1212#ifdef USE_IPEX
1313#include < ipex.h>
1414#else
15+ #include < ATen/record_function.h>
1516#include < c10/xpu/XPUStream.h>
1617#endif
1718
@@ -43,9 +44,7 @@ at::Tensor softmax(const at::Tensor &input, const at::Tensor &output,
4344 const int64_t dim) {
4445 CHECK_INPUT (input);
4546 CHECK_INPUT (output);
46- #ifdef USE_IPEX
4747 RECORD_FUNCTION (" xetla softmax" , {});
48- #endif
4948
5049 auto queue = get_current_sycl_queue ();
5150 auto evt = softmax_forward<T>(input.data_ptr (), output.data_ptr (), queue);
@@ -63,9 +62,7 @@ at::Tensor bf16_gemm(const at::Tensor &a, const at::Tensor &b,
6362 CHECK_INPUT (b);
6463 CHECK_INPUT (c);
6564 CHECK_INPUT (acc);
66- #ifdef USE_IPEX
6765 RECORD_FUNCTION (" xetla gemm" , {});
68- #endif
6966
7067 auto queue = get_current_sycl_queue ();
7168 auto evt = gemm_run<T>(a.data_ptr (), b.data_ptr (), c.data_ptr (),
@@ -83,9 +80,7 @@ at::Tensor bf16_stream_k_gemm(const at::Tensor &a, const at::Tensor &b,
8380 CHECK_INPUT (b);
8481 CHECK_INPUT (c);
8582 CHECK_INPUT (acc);
86- #ifdef USE_IPEX
8783 RECORD_FUNCTION (" xetla stream_k_gemm" , {});
88- #endif
8984
9085 auto queue = get_current_sycl_queue ();
9186 auto evt = stream_k_gemm_run (a.data_ptr (), b.data_ptr (), c.data_ptr (),
@@ -105,9 +100,7 @@ at::Tensor bf16_split_k_gemm(const at::Tensor &a, const at::Tensor &b,
105100 CHECK_INPUT (b);
106101 CHECK_INPUT (c);
107102 CHECK_INPUT (acc);
108- #ifdef USE_IPEX
109103 RECORD_FUNCTION (" xetla split_k_gemm" , {});
110- #endif
111104
112105 auto queue = get_current_sycl_queue ();
113106 auto evt = split_k_gemm_run<m, k, n, kslicing_type>(
@@ -143,9 +136,7 @@ void flash_attn(const at::Tensor &q, const at::Tensor &k, const at::Tensor &v,
143136 CHECK_INPUT (bias);
144137 CHECK_INPUT (m);
145138 CHECK_INPUT (l);
146- #ifdef USE_IPEX
147139 RECORD_FUNCTION (" xetla fa" , {});
148- #endif
149140
150141 auto queue = get_current_sycl_queue ();
151142
@@ -212,9 +203,7 @@ void flash_attn_bwd(const at::Tensor &grad_out, const at::Tensor &q,
212203 CHECK_INPUT (grad_value);
213204 CHECK_INPUT (grad_bias);
214205
215- #ifdef USE_IPEX
216206 RECORD_FUNCTION (" xetla fa" , {});
217- #endif
218207
219208 auto queue = get_current_sycl_queue ();
220209
0 commit comments