@@ -47,6 +47,12 @@ using namespace dpctl::tensor::offset_utils;
 template <typename srcT, typename dstT, typename IndexerT>
 class copy_cast_generic_kernel;
 
+template <typename srcT,
+          typename dstT,
+          unsigned int vec_sz,
+          unsigned int n_vecs>
+class copy_cast_contig_kernel;
+
 template <typename srcT, typename dstT, typename IndexerT>
 class copy_cast_from_host_kernel;
 
@@ -191,6 +197,166 @@ template <typename fnT, typename D, typename S> struct CopyAndCastGenericFactory
     }
 };
 
+// Specialization of copy_and_cast for contiguous arrays of different data types
+
+template <typename srcT,
+          typename dstT,
+          typename CastFnT,
+          int vec_sz = 4,
+          int n_vecs = 2>
+class ContigCopyFunctor
+{
+private:
+    const size_t nelems;
+    const srcT *src_p = nullptr;
+    dstT *dst_p = nullptr;
+
+public:
+    ContigCopyFunctor(const size_t nelems_, const srcT *src_p_, dstT *dst_p_)
+        : nelems(nelems_), src_p(src_p_), dst_p(dst_p_)
+    {
+    }
+
+    void operator()(sycl::nd_item<1> ndit) const
+    {
+        CastFnT fn{};
+
+        using dpctl::tensor::type_utils::is_complex;
+        if constexpr (is_complex<srcT>::value || is_complex<dstT>::value) {
+            std::uint8_t sgSize = ndit.get_sub_group().get_local_range()[0];
+            size_t base = ndit.get_global_linear_id();
+
+            base = (base / sgSize) * sgSize * n_vecs * vec_sz + (base % sgSize);
+            for (size_t offset = base;
+                 offset < std::min(nelems, base + sgSize * (n_vecs * vec_sz));
+                 offset += sgSize)
+            {
+                dst_p[offset] = fn(src_p[offset]);
+            }
+        }
+        else {
+            auto sg = ndit.get_sub_group();
+            std::uint8_t sgSize = sg.get_local_range()[0];
+            std::uint8_t max_sgSize = sg.get_max_local_range()[0];
+            size_t base = n_vecs * vec_sz *
+                          (ndit.get_group(0) * ndit.get_local_range(0) +
+                           sg.get_group_id()[0] * max_sgSize);
+
+            if (base + n_vecs * vec_sz * sgSize < nelems &&
+                sgSize == max_sgSize) {
+                using src_ptrT =
+                    sycl::multi_ptr<const srcT,
+                                    sycl::access::address_space::global_space>;
+                using dst_ptrT =
+                    sycl::multi_ptr<dstT,
+                                    sycl::access::address_space::global_space>;
+                sycl::vec<srcT, vec_sz> src_vec;
+                sycl::vec<dstT, vec_sz> dst_vec;
+
+#pragma unroll
+                for (std::uint8_t it = 0; it < n_vecs * vec_sz; it += vec_sz) {
+                    src_vec =
+                        sg.load<vec_sz>(src_ptrT(&src_p[base + it * sgSize]));
+#pragma unroll
+                    for (std::uint8_t k = 0; k < vec_sz; k++) {
+                        dst_vec[k] = fn(src_vec[k]);
+                    }
+                    sg.store<vec_sz>(dst_ptrT(&dst_p[base + it * sgSize]),
+                                     dst_vec);
+                }
+            }
+            else {
+                for (size_t k = base + sg.get_local_id()[0]; k < nelems;
+                     k += sgSize) {
+                    dst_p[k] = fn(src_p[k]);
+                }
+            }
+        }
+    }
+};
+
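In the functor's non-complex fast path, each sub-group moves n_vecs chunks of vec_sz elements per work-item via vectorized loads/stores. A minimal host-side sketch of that addressing scheme (plain C++; the sub-group size and ids below are hypothetical, not taken from the diff):

    #include <cstddef>
    #include <cstdio>

    int main()
    {
        constexpr unsigned int vec_sz = 4;
        constexpr unsigned int n_vecs = 2;
        constexpr std::size_t lws = 64;     // work-group size used by the impl
        constexpr std::size_t sg_size = 16; // hypothetical sub-group size
        constexpr std::size_t group_id = 0; // hypothetical work-group id
        constexpr std::size_t sg_id = 1;    // hypothetical sub-group id

        // Same expression as in ContigCopyFunctor's non-complex branch
        std::size_t base = n_vecs * vec_sz * (group_id * lws + sg_id * sg_size);

        // Each iteration corresponds to one sub-group load/store of vec_sz
        // elements per work-item, strided by the sub-group size
        for (unsigned int it = 0; it < n_vecs * vec_sz; it += vec_sz) {
            std::printf("sub-group chunk begins at element %zu\n",
                        base + it * sg_size);
        }
        return 0;
    }
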
+/*!
+ * @brief Function pointer type for contiguous array cast and copy function.
+ */
+typedef sycl::event (*copy_and_cast_contig_fn_ptr_t)(
+    sycl::queue,
+    size_t,
+    const char *,
+    char *,
+    const std::vector<sycl::event> &);
+
+/*!
+ * @brief Function to copy `nelems` elements from contiguous `src` usm_ndarray
+ to contiguous `dst` usm_ndarray while casting from `srcTy` to `dstTy`.
+
+ Both arrays have the same number of elements `nelems`.
+ `src_cp` and `dst_cp` are char pointers to the start of the respective
+ arrays. The kernel is submitted to sycl queue `q` with events `depends` as
+ dependencies.
+
+ @param q Sycl queue to which the kernel is submitted.
+ @param nelems Number of elements to cast and copy.
+ @param src_cp Kernel-accessible USM pointer for the source array.
+ @param dst_cp Kernel-accessible USM pointer for the destination array.
+ @param depends List of events to wait for before starting computations, if
+ any.
+
+ @return Event to wait on to ensure that computation completes.
+ @ingroup CopyAndCastKernels
+ */
+template <typename dstTy, typename srcTy>
+sycl::event copy_and_cast_contig_impl(sycl::queue q,
+                                      size_t nelems,
+                                      const char *src_cp,
+                                      char *dst_cp,
+                                      const std::vector<sycl::event> &depends)
+{
+    dpctl::tensor::type_utils::validate_type_for_device<dstTy>(q);
+    dpctl::tensor::type_utils::validate_type_for_device<srcTy>(q);
+
+    sycl::event copy_and_cast_ev = q.submit([&](sycl::handler &cgh) {
+        cgh.depends_on(depends);
+
+        const srcTy *src_tp = reinterpret_cast<const srcTy *>(src_cp);
+        dstTy *dst_tp = reinterpret_cast<dstTy *>(dst_cp);
+
+        size_t lws = 64;
+        constexpr unsigned int vec_sz = 4;
+        constexpr unsigned int n_vecs = 2;
+        const size_t n_groups =
+            ((nelems + lws * n_vecs * vec_sz - 1) / (lws * n_vecs * vec_sz));
+        const auto gws_range = sycl::range<1>(n_groups * lws);
+        const auto lws_range = sycl::range<1>(lws);
+
+        cgh.parallel_for<copy_cast_contig_kernel<srcTy, dstTy, n_vecs, vec_sz>>(
+            sycl::nd_range<1>(gws_range, lws_range),
+            ContigCopyFunctor<srcTy, dstTy, Caster<srcTy, dstTy>, vec_sz,
+                              n_vecs>(nelems, src_tp, dst_tp));
+    });
+
+    return copy_and_cast_ev;
+}
+
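A hypothetical call site for the function above, assuming `src` and `dst` are device-accessible USM allocations holding `n` elements each (the wrapper name and variables are illustrative, not part of the diff):

    #include <sycl/sycl.hpp>
    #include <cstddef>

    sycl::event cast_float_to_double(sycl::queue q,
                                     const float *src, // USM, n elements
                                     double *dst,      // USM, n elements
                                     std::size_t n)
    {
        // dstTy is the first template argument, srcTy the second;
        // the empty vector means no dependencies
        return copy_and_cast_contig_impl<double, float>(
            q, n, reinterpret_cast<const char *>(src),
            reinterpret_cast<char *>(dst), {});
    }
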
+/*!
+ * @brief Factory to get specialized function pointer for casting and copying
+ * contiguous arrays of different types.
+ * @ingroup CopyAndCastKernels
+ */
+template <typename fnT, typename D, typename S> struct CopyAndCastContigFactory
+{
+    fnT get()
+    {
+        if constexpr (std::is_same_v<D, S>) {
+            fnT fn = nullptr;
+            return fn;
+        }
+        else {
+            fnT f = copy_and_cast_contig_impl<D, S>;
+            return f;
+        }
+    }
+};
+
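The factory would typically be instantiated over all supported (destination, source) type pairs to fill a 2-D dispatch table; that machinery lies outside this diff. A minimal sketch of pulling out individual function pointers:

    // Distinct types yield a pointer to the casting copy implementation
    copy_and_cast_contig_fn_ptr_t fn =
        CopyAndCastContigFactory<copy_and_cast_contig_fn_ptr_t, double, float>{}
            .get();

    // When D == S, get() returns nullptr, presumably so the caller can fall
    // back to a plain same-type contiguous copy (an assumption; the fallback
    // is not shown in this diff)
    copy_and_cast_contig_fn_ptr_t same =
        CopyAndCastContigFactory<copy_and_cast_contig_fn_ptr_t, float, float>{}
            .get();
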
 // Specialization of copy_and_cast for 1D arrays
 
 /*!