1+ // *****************************************************************************
2+ // Copyright (c) 2024, Intel Corporation
3+ // All rights reserved.
4+ //
5+ // Redistribution and use in source and binary forms, with or without
6+ // modification, are permitted provided that the following conditions are met:
7+ // - Redistributions of source code must retain the above copyright notice,
8+ // this list of conditions and the following disclaimer.
9+ // - Redistributions in binary form must reproduce the above copyright notice,
10+ // this list of conditions and the following disclaimer in the documentation
11+ // and/or other materials provided with the distribution.
12+ //
13+ // THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
14+ // AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
15+ // IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
16+ // ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
17+ // LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
18+ // CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
19+ // SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
20+ // INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
21+ // CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
22+ // ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
23+ // THE POSSIBILITY OF SUCH DAMAGE.
24+ // *****************************************************************************
25+
26+ #pragma once
27+
28+ #include < pybind11/pybind11.h>
29+
30+ #include < sycl/sycl.hpp>
31+ #include < oneapi/mkl/rng/device.hpp>
32+
33+ // dpctl tensor headers
34+ #include " kernels/alignment.hpp"
35+ #include " utils/offset_utils.hpp"
36+
37+ namespace dpnp
38+ {
39+ namespace backend
40+ {
41+ namespace ext
42+ {
43+ namespace rng
44+ {
45+ namespace device
46+ {
47+ namespace details
48+ {
49+ namespace py = pybind11;
50+
51+ using dpctl::tensor::kernels::alignment_utils::is_aligned;
52+ using dpctl::tensor::kernels::alignment_utils::required_alignment;
53+
54+ namespace mkl_rng_dev = oneapi::mkl::rng::device;
55+
56+ /* ! @brief Functor for unary function evaluation on contiguous array */
57+ template <typename DataT,
58+ typename ResT = DataT,
59+ typename Method = mkl_rng_dev::gaussian_method::by_default,
60+ typename UnaryOperatorT = ResT,
61+ unsigned int vec_sz = 8 ,
62+ unsigned int items_per_wi = 4 ,
63+ bool enable_sg_load = true >
64+ struct RngContigFunctor
65+ {
66+ private:
67+ const std::uint32_t seed_;
68+ const DataT mean_;
69+ const DataT stddev_;
70+ ResT *res_ = nullptr ;
71+ const size_t nelems_;
72+
73+ public:
74+ RngContigFunctor (const std::uint32_t seed, const DataT mean, const DataT stddev, ResT *res, const size_t n_elems)
75+ : seed_(seed), mean_(mean), stddev_(stddev), res_(res), nelems_(n_elems)
76+ {
77+ }
78+
79+ void operator ()(sycl::nd_item<1 > nd_it) const
80+ {
81+ auto global_id = nd_it.get_global_id ();
82+
83+ auto sg = nd_it.get_sub_group ();
84+ const std::uint8_t sg_size = sg.get_local_range ()[0 ];
85+ const std::uint8_t max_sg_size = sg.get_max_local_range ()[0 ];
86+
87+ auto engine = mkl_rng_dev::mrg32k3a<vec_sz>(seed_, nelems_ * global_id);
88+ mkl_rng_dev::gaussian<DataT, Method> distr (mean_, stddev_);
89+
90+ if (enable_sg_load) {
91+ const size_t base = items_per_wi * vec_sz * (nd_it.get_group (0 ) * nd_it.get_local_range (0 ) + sg.get_group_id ()[0 ] * max_sg_size);
92+
93+ if ((sg_size == max_sg_size) && (base + items_per_wi * vec_sz * sg_size < nelems_)) {
94+ #pragma unroll
95+ for (std::uint16_t it = 0 ; it < items_per_wi * vec_sz; it += vec_sz) {
96+ size_t offset = base + static_cast <size_t >(it) * static_cast <size_t >(sg_size);
97+ auto out_multi_ptr = sycl::address_space_cast<sycl::access::address_space::global_space, sycl::access::decorated::yes>(&res_[offset]);
98+
99+ sycl::vec<DataT, vec_sz> rng_val_vec = mkl_rng_dev::generate (distr, engine);
100+ sg.store <vec_sz>(out_multi_ptr, rng_val_vec);
101+ }
102+ }
103+ else {
104+ for (size_t offset = base + sg.get_local_id ()[0 ]; offset < nelems_; offset += sg_size) {
105+ res_[offset] = mkl_rng_dev::generate_single (distr, engine);
106+ }
107+ }
108+ }
109+ else {
110+ size_t base = nd_it.get_global_linear_id ();
111+
112+ base = (base / sg_size) * sg_size * items_per_wi * vec_sz + (base % sg_size);
113+ for (size_t offset = base; offset < std::min (nelems_, base + sg_size * (items_per_wi * vec_sz)); offset += sg_size)
114+ {
115+ res_[offset] = mkl_rng_dev::generate_single (distr, engine);
116+ }
117+ }
118+ }
119+ };
120+
121+ template <typename DataT,
122+ typename ResT = DataT,
123+ typename Method = mkl_rng_dev::gaussian_method::by_default,
124+ typename IndexerT = ResT,
125+ typename UnaryOpT = ResT>
126+ struct RngStridedFunctor
127+ {
128+ private:
129+ const std::uint32_t seed_;
130+ const double mean_;
131+ const double stddev_;
132+ ResT *res_ = nullptr ;
133+ IndexerT out_indexer_;
134+
135+ public:
136+ RngStridedFunctor (const std::uint32_t seed, const double mean, const double stddev, ResT *res_p, IndexerT out_indexer)
137+ : seed_(seed), mean_(mean), stddev_(stddev), res_(res_p), out_indexer_(out_indexer)
138+ {
139+ }
140+
141+ void operator ()(sycl::id<1 > wid) const
142+ {
143+ const auto res_offset = out_indexer_ (wid.get (0 ));
144+
145+ // UnaryOpT op{};
146+
147+ auto engine = mkl_rng_dev::mrg32k3a (seed_);
148+ mkl_rng_dev::gaussian<DataT, Method> distr (mean_, stddev_);
149+
150+ res_[res_offset] = mkl_rng_dev::generate (distr, engine);
151+ }
152+ };
153+ } // namespace details
154+ } // namespace device
155+ } // namespace rng
156+ } // namespace ext
157+ } // namespace backend
158+ } // namespace dpnp
0 commit comments