|
24 | 24 | //===----------------------------------------------------------------------===// |
25 | 25 |
|
26 | 26 | #pragma once |
| 27 | +#include <complex> |
| 28 | +#include <cstddef> |
| 29 | + |
| 30 | +#include <sycl/sycl.hpp> |
| 31 | + |
27 | 32 | #include "dpctl_tensor_types.hpp" |
28 | 33 | #include "utils/offset_utils.hpp" |
29 | 34 | #include "utils/strided_iters.hpp" |
30 | 35 | #include "utils/type_utils.hpp" |
31 | | -#include <complex> |
32 | | -#include <cstddef> |
33 | | -#include <sycl/sycl.hpp> |
34 | 36 |
|
35 | 37 | namespace dpctl |
36 | 38 | { |
@@ -200,22 +202,25 @@ sycl::event lin_space_affine_impl(sycl::queue &exec_q, |
200 | 202 | { |
201 | 203 | dpctl::tensor::type_utils::validate_type_for_device<Ty>(exec_q); |
202 | 204 |
|
203 | | - bool device_supports_doubles = exec_q.get_device().has(sycl::aspect::fp64); |
| 205 | + const bool device_supports_doubles = |
| 206 | + exec_q.get_device().has(sycl::aspect::fp64); |
| 207 | + const std::size_t den = (include_endpoint) ? nelems - 1 : nelems; |
| 208 | + |
204 | 209 | sycl::event lin_space_affine_event = exec_q.submit([&](sycl::handler &cgh) { |
205 | 210 | cgh.depends_on(depends); |
206 | 211 | if (device_supports_doubles) { |
207 | | - cgh.parallel_for<linear_sequence_affine_kernel<Ty, double>>( |
208 | | - sycl::range<1>{nelems}, |
209 | | - LinearSequenceAffineFunctor<Ty, double>( |
210 | | - array_data, start_v, end_v, |
211 | | - (include_endpoint) ? nelems - 1 : nelems)); |
| 212 | + using KernelName = linear_sequence_affine_kernel<Ty, double>; |
| 213 | + using Impl = LinearSequenceAffineFunctor<Ty, double>; |
| 214 | + |
| 215 | + cgh.parallel_for<KernelName>(sycl::range<1>{nelems}, |
| 216 | + Impl(array_data, start_v, end_v, den)); |
212 | 217 | } |
213 | 218 | else { |
214 | | - cgh.parallel_for<linear_sequence_affine_kernel<Ty, float>>( |
215 | | - sycl::range<1>{nelems}, |
216 | | - LinearSequenceAffineFunctor<Ty, float>( |
217 | | - array_data, start_v, end_v, |
218 | | - (include_endpoint) ? nelems - 1 : nelems)); |
| 219 | + using KernelName = linear_sequence_affine_kernel<Ty, float>; |
| 220 | + using Impl = LinearSequenceAffineFunctor<Ty, float>; |
| 221 | + |
| 222 | + cgh.parallel_for<KernelName>(sycl::range<1>{nelems}, |
| 223 | + Impl(array_data, start_v, end_v, den)); |
219 | 224 | } |
220 | 225 | }); |
221 | 226 |
|
@@ -312,10 +317,12 @@ sycl::event full_strided_impl(sycl::queue &q, |
312 | 317 |
|
313 | 318 | sycl::event fill_ev = q.submit([&](sycl::handler &cgh) { |
314 | 319 | cgh.depends_on(depends); |
315 | | - cgh.parallel_for<full_strided_kernel<dstTy>>( |
316 | | - sycl::range<1>{nelems}, |
317 | | - FullStridedFunctor<dstTy, decltype(strided_indexer)>( |
318 | | - dst_tp, fill_v, strided_indexer)); |
| 320 | + |
| 321 | + using KernelName = full_strided_kernel<dstTy>; |
| 322 | + using Impl = FullStridedFunctor<dstTy, StridedIndexer>; |
| 323 | + |
| 324 | + cgh.parallel_for<KernelName>(sycl::range<1>{nelems}, |
| 325 | + Impl(dst_tp, fill_v, strided_indexer)); |
319 | 326 | }); |
320 | 327 |
|
321 | 328 | return fill_ev; |
@@ -388,9 +395,12 @@ sycl::event eye_impl(sycl::queue &exec_q, |
388 | 395 | dpctl::tensor::type_utils::validate_type_for_device<Ty>(exec_q); |
389 | 396 | sycl::event eye_event = exec_q.submit([&](sycl::handler &cgh) { |
390 | 397 | cgh.depends_on(depends); |
391 | | - cgh.parallel_for<eye_kernel<Ty>>( |
392 | | - sycl::range<1>{nelems}, |
393 | | - EyeFunctor<Ty>(array_data, start, end, step)); |
| 398 | + |
| 399 | + using KernelName = eye_kernel<Ty>; |
| 400 | + using Impl = EyeFunctor<Ty>; |
| 401 | + |
| 402 | + cgh.parallel_for<KernelName>(sycl::range<1>{nelems}, |
| 403 | + Impl(array_data, start, end, step)); |
394 | 404 | }); |
395 | 405 |
|
396 | 406 | return eye_event; |
@@ -478,7 +488,7 @@ sycl::event tri_impl(sycl::queue &exec_q, |
478 | 488 | ssize_t inner_gid = idx[0] - inner_range * outer_gid; |
479 | 489 |
|
480 | 490 | ssize_t src_inner_offset = 0, dst_inner_offset = 0; |
481 | | - bool to_copy(true); |
| 491 | + bool to_copy{false}; |
482 | 492 |
|
483 | 493 | { |
484 | 494 | using dpctl::tensor::strides::CIndexer_array; |
|
0 commit comments