Skip to content

Commit 279d463

Browse files
committed
FEATURE: random number generation functions
* Moved old rand* functions to random module * New object `RandomEngine` is added * Additional random number generation functions that use RandomEngine object are added.
1 parent 8330324 commit 279d463

File tree

5 files changed

+308
-52
lines changed

5 files changed

+308
-52
lines changed

src/data/mod.rs

Lines changed: 16 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -38,12 +38,6 @@ extern {
3838
fn af_iota(out: MutAfArray, ndims: c_uint, dims: *const DimT,
3939
t_ndims: c_uint, tdims: *const DimT, afdtype: uint8_t) -> c_int;
4040

41-
fn af_randu(out: MutAfArray, ndims: c_uint, dims: *const DimT, afdtype: uint8_t) -> c_int;
42-
fn af_randn(out: MutAfArray, ndims: c_uint, dims: *const DimT, afdtype: uint8_t) -> c_int;
43-
44-
fn af_set_seed(seed: Uintl) -> c_int;
45-
fn af_get_seed(seed: *mut Uintl) -> c_int;
46-
4741
fn af_identity(out: MutAfArray, ndims: c_uint, dims: *const DimT, afdtype: uint8_t) -> c_int;
4842
fn af_diag_create(out: MutAfArray, arr: AfArray, num: c_int) -> c_int;
4943
fn af_diag_extract(out: MutAfArray, arr: AfArray, num: c_int) -> c_int;
@@ -250,55 +244,28 @@ pub fn iota<T: HasAfEnum>(dims: Dim4, tdims: Dim4) -> Array {
250244
}
251245
}
252246

253-
/// Set seed for random number generation
254-
pub fn set_seed(seed: u64) {
255-
unsafe {
256-
let err_val = af_set_seed(seed as Uintl);
257-
HANDLE_ERROR(AfError::from(err_val));
258-
}
259-
}
260-
261-
/// Get the seed of random number generator
247+
/// Create an identity array with 1's in diagonal
248+
///
249+
/// # Parameters
250+
///
251+
/// - `dims` is the output Array dimensions
252+
///
253+
/// # Return Values
254+
///
255+
/// Identity matrix
262256
#[allow(unused_mut)]
263-
pub fn get_seed() -> u64 {
257+
pub fn identity<T: HasAfEnum>(dims: Dim4) -> Array {
264258
unsafe {
265-
let mut temp: u64 = 0;
266-
let err_val = af_get_seed(&mut temp as *mut Uintl);
259+
let aftype = T::get_af_dtype();
260+
let mut temp: i64 = 0;
261+
let err_val = af_identity(&mut temp as MutAfArray,
262+
dims.ndims() as c_uint, dims.get().as_ptr() as *const DimT,
263+
aftype as uint8_t);
267264
HANDLE_ERROR(AfError::from(err_val));
268-
temp
265+
Array::from(temp)
269266
}
270267
}
271268

272-
macro_rules! data_gen_def {
273-
($doc_str: expr, $fn_name:ident, $ffi_name: ident) => (
274-
#[doc=$doc_str]
275-
///
276-
///# Parameters
277-
///
278-
/// - `dims` is the output dimensions
279-
///
280-
///# Return Values
281-
///
282-
/// An Array with modified data.
283-
#[allow(unused_mut)]
284-
pub fn $fn_name<T: HasAfEnum>(dims: Dim4) -> Array {
285-
unsafe {
286-
let aftype = T::get_af_dtype();
287-
let mut temp: i64 = 0;
288-
let err_val = $ffi_name(&mut temp as MutAfArray,
289-
dims.ndims() as c_uint, dims.get().as_ptr() as *const DimT,
290-
aftype as uint8_t);
291-
HANDLE_ERROR(AfError::from(err_val));
292-
Array::from(temp)
293-
}
294-
}
295-
)
296-
}
297-
298-
data_gen_def!("Create random numbers from uniform distribution", randu, af_randu);
299-
data_gen_def!("Create random numbers from normal distribution", randn, af_randn);
300-
data_gen_def!("Create an identity array with 1's in diagonal", identity, af_identity);
301-
302269
/// Create a diagonal matrix
303270
///
304271
/// # Parameters

src/defines.rs

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -380,3 +380,20 @@ pub enum BinaryOp {
380380
MIN = 2,
381381
MAX = 3
382382
}
383+
384+
/// Random engine types
385+
#[repr(C)]
386+
#[derive(Clone, Copy, Debug, PartialEq)]
387+
pub enum RandomEngineType {
388+
///Philox variant with N=4, W=32 and Rounds=10
389+
PHILOX_4X32_10 = 100,
390+
///Threefry variant with N=2, W=32 and Rounds=16
391+
THREEFRY_2X32_16 = 200,
392+
///Mersenne variant with MEXP = 11213
393+
MERSENNE_GP11213 = 300
394+
}
395+
396+
pub const PHILOX : RandomEngineType = RandomEngineType::PHILOX_4X32_10;
397+
pub const THREEFRY : RandomEngineType = RandomEngineType::THREEFRY_2X32_16;
398+
pub const MERSENNE : RandomEngineType = RandomEngineType::MERSENNE_GP11213;
399+
pub const DEFAULT_RANDOM_ENGINE : RandomEngineType = PHILOX;

src/lib.rs

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,6 @@ pub use blas::{matmul, dot, transpose, transpose_inplace};
3333
mod blas;
3434

3535
pub use data::{constant, range, iota};
36-
pub use data::{set_seed, get_seed, randu, randn};
3736
pub use data::{identity, diag_create, diag_extract, lower, upper};
3837
pub use data::{join, join_many, tile};
3938
pub use data::{reorder, shift, moddims, flat, flip};
@@ -47,7 +46,8 @@ mod device;
4746
pub use defines::{DType, AfError, Backend, ColorMap, YCCStd, HomographyType};
4847
pub use defines::{InterpType, BorderType, MatchType, NormType};
4948
pub use defines::{Connectivity, ConvMode, ConvDomain, ColorSpace, MatProp};
50-
pub use defines::{MarkerType, MomentType, SparseFormat, BinaryOp};
49+
pub use defines::{MarkerType, MomentType, SparseFormat, BinaryOp, RandomEngineType};
50+
pub use defines::{PHILOX, THREEFRY, MERSENNE, DEFAULT_RANDOM_ENGINE};
5151
mod defines;
5252

5353
pub use dim4::Dim4;
@@ -83,6 +83,11 @@ mod lapack;
8383
mod macros;
8484
mod num;
8585

86+
pub use random::RandomEngine;
87+
pub use random::{set_seed, get_seed, randu, randn, random_uniform, random_normal};
88+
pub use random::{get_default_random_engine, set_default_random_engine_type};
89+
mod random;
90+
8691
pub use signal::{approx1, approx2, set_fft_plan_cache_size};
8792
pub use signal::{fft, fft2, fft3, ifft, ifft2, ifft3};
8893
pub use signal::{fft_r2c, fft2_r2c, fft3_r2c, fft_c2r, fft2_c2r, fft3_c2r};

src/random/mod.rs

Lines changed: 260 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,260 @@
1+
extern crate libc;
2+
3+
use array::Array;
4+
use dim4::Dim4;
5+
use defines::{AfError, RandomEngineType};
6+
use error::HANDLE_ERROR;
7+
use self::libc::{uint8_t, c_int, c_uint, c_ulong};
8+
use util::HasAfEnum;
9+
10+
type MutAfArray = *mut self::libc::c_longlong;
11+
type AfArray = self::libc::c_longlong;
12+
type MutRandEngine = *mut self::libc::c_longlong;
13+
type RandEngine = self::libc::c_longlong;
14+
type DimT = self::libc::c_longlong;
15+
type Intl = self::libc::c_longlong;
16+
type Uintl = self::libc::c_ulonglong;
17+
18+
#[allow(dead_code)]
19+
extern {
20+
fn af_set_seed(seed: Uintl) -> c_int;
21+
fn af_get_seed(seed: *mut Uintl) -> c_int;
22+
23+
fn af_randu(out: MutAfArray, ndims: c_uint, dims: *const DimT, afdtype: uint8_t) -> c_int;
24+
fn af_randn(out: MutAfArray, ndims: c_uint, dims: *const DimT, afdtype: uint8_t) -> c_int;
25+
26+
fn af_create_random_engine(engine: MutRandEngine, rtype: uint8_t, seed: Uintl) -> c_int;
27+
fn af_retain_random_engine(engine: MutRandEngine, inputEngine: RandEngine) -> c_int;
28+
fn af_random_engine_set_type(engine: MutRandEngine, rtpye: uint8_t) -> c_int;
29+
fn af_random_engine_get_type(rtype: *mut uint8_t, engine: RandEngine) -> c_int;
30+
fn af_random_engine_set_seed(engine: MutRandEngine, seed: Uintl) -> c_int;
31+
fn af_random_engine_get_seed(seed: *mut Uintl, engine: RandEngine) -> c_int;
32+
fn af_release_random_engine(engine: RandEngine) -> c_int;
33+
34+
fn af_get_default_random_engine(engine: MutRandEngine) -> c_int;
35+
fn af_set_default_random_engine_type(rtype: uint8_t) -> c_int;
36+
37+
fn af_random_uniform(out: MutAfArray, ndims: c_uint, dims: *const DimT,
38+
aftype: uint8_t, engine: RandEngine) -> c_int;
39+
fn af_random_normal(out: MutAfArray, ndims: c_uint, dims: *const DimT,
40+
aftype: uint8_t, engine: RandEngine) -> c_int;
41+
}
42+
43+
/// Set seed for random number generation
44+
pub fn set_seed(seed: u64) {
45+
unsafe {
46+
let err_val = af_set_seed(seed as Uintl);
47+
HANDLE_ERROR(AfError::from(err_val));
48+
}
49+
}
50+
51+
/// Get the seed of random number generator
52+
#[allow(unused_mut)]
53+
pub fn get_seed() -> u64 {
54+
unsafe {
55+
let mut temp: u64 = 0;
56+
let err_val = af_get_seed(&mut temp as *mut Uintl);
57+
HANDLE_ERROR(AfError::from(err_val));
58+
temp
59+
}
60+
}
61+
62+
macro_rules! data_gen_def {
63+
($doc_str: expr, $fn_name:ident, $ffi_name: ident) => (
64+
#[doc=$doc_str]
65+
///
66+
///# Parameters
67+
///
68+
/// - `dims` is the output dimensions
69+
///
70+
///# Return Values
71+
///
72+
/// An Array with random values.
73+
#[allow(unused_mut)]
74+
pub fn $fn_name<T: HasAfEnum>(dims: Dim4) -> Array {
75+
unsafe {
76+
let aftype = T::get_af_dtype();
77+
let mut temp: i64 = 0;
78+
let err_val = $ffi_name(&mut temp as MutAfArray,
79+
dims.ndims() as c_uint, dims.get().as_ptr() as *const DimT,
80+
aftype as uint8_t);
81+
HANDLE_ERROR(AfError::from(err_val));
82+
Array::from(temp)
83+
}
84+
}
85+
)
86+
}
87+
88+
data_gen_def!("Create random numbers from uniform distribution", randu, af_randu);
89+
data_gen_def!("Create random numbers from normal distribution", randn, af_randn);
90+
91+
/// Random number generator engine
92+
///
93+
/// This is a wrapper for ArrayFire's native random number generator engine.
94+
pub struct RandomEngine {
95+
handle: i64,
96+
}
97+
98+
/// Used for creating RandomEngine object from native resource id
99+
impl From<i64> for RandomEngine {
100+
fn from(t: i64) -> RandomEngine {
101+
RandomEngine {handle: t}
102+
}
103+
}
104+
105+
impl RandomEngine {
106+
/// Create a new random engine object
107+
///
108+
/// # Parameters
109+
///
110+
/// - `rengine` can be value of [RandomEngineType](./enum.RandomEngineType.html) enum.
111+
/// - `seed` is the initial seed value
112+
///
113+
/// # Return Values
114+
///
115+
/// A object of type RandomEngine
116+
pub fn new(rengine: RandomEngineType, seed: Option<u64>) -> RandomEngine {
117+
unsafe {
118+
let mut temp: i64 = 0;
119+
let err_val = af_create_random_engine(&mut temp as MutRandEngine, rengine as uint8_t,
120+
match seed {Some(s) => s, None => 0} as c_ulong);
121+
HANDLE_ERROR(AfError::from(err_val));
122+
RandomEngine::from(temp)
123+
}
124+
}
125+
126+
/// Get random engine type
127+
pub fn get_type(&self) -> RandomEngineType {
128+
unsafe {
129+
let mut temp: u8 = 0;
130+
let err_val = af_random_engine_get_type(&mut temp as *mut uint8_t,
131+
self.handle as RandEngine);
132+
HANDLE_ERROR(AfError::from(err_val));
133+
RandomEngineType::from(temp as i32)
134+
}
135+
}
136+
137+
/// Get random engine type
138+
pub fn set_type(&mut self, engine_type: RandomEngineType) {
139+
unsafe {
140+
let err_val = af_random_engine_set_type(&mut self.handle as MutRandEngine,
141+
engine_type as uint8_t);
142+
HANDLE_ERROR(AfError::from(err_val));
143+
}
144+
}
145+
146+
/// Set seed for random engine
147+
pub fn set_seed(&mut self, seed: u64) {
148+
unsafe {
149+
let err_val = af_random_engine_set_seed(&mut self.handle as MutRandEngine,
150+
seed as Uintl);
151+
HANDLE_ERROR(AfError::from(err_val));
152+
}
153+
}
154+
155+
/// Get seed of the random engine
156+
pub fn get_seed(&self) -> u64 {
157+
unsafe {
158+
let mut seed: u64 = 0;
159+
let err_val = af_random_engine_get_seed(&mut seed as *mut Uintl, self.handle as RandEngine);
160+
HANDLE_ERROR(AfError::from(err_val));
161+
seed
162+
}
163+
}
164+
165+
/// Returns the native FFI handle for Rust object `RandomEngine`
166+
pub fn get(&self) -> i64 {
167+
self.handle
168+
}
169+
}
170+
171+
/// Increment reference count of RandomEngine's native resource
172+
impl Clone for RandomEngine {
173+
fn clone(&self) -> RandomEngine {
174+
unsafe {
175+
let mut temp: i64 = 0;
176+
let err_val = af_retain_random_engine(&mut temp as MutRandEngine, self.handle as RandEngine);
177+
HANDLE_ERROR(AfError::from(err_val));
178+
RandomEngine::from(temp)
179+
}
180+
}
181+
}
182+
183+
/// Free RandomEngine's native resource
184+
impl Drop for RandomEngine {
185+
fn drop(&mut self) {
186+
unsafe {
187+
let err_val = af_release_random_engine(self.handle as RandEngine);
188+
HANDLE_ERROR(AfError::from(err_val));
189+
}
190+
}
191+
}
192+
193+
/// Get default random engine
194+
pub fn get_default_random_engine() -> RandomEngine {
195+
unsafe {
196+
let mut temp : i64 = 0;
197+
let mut err_val = af_get_default_random_engine(&mut temp as MutRandEngine);
198+
HANDLE_ERROR(AfError::from(err_val));
199+
let mut handle : i64 = 0;
200+
err_val = af_retain_random_engine(&mut handle as MutRandEngine, temp as RandEngine);
201+
HANDLE_ERROR(AfError::from(err_val));
202+
RandomEngine::from(handle)
203+
}
204+
}
205+
206+
/// Set the random engine type for default random number generator
207+
///
208+
/// # Parameters
209+
///
210+
/// - `rtype` can take one of the values of enum [RandomEngineType](./enum.RandomEngineType.html)
211+
pub fn set_default_random_engine_type(rtype: RandomEngineType) {
212+
unsafe {
213+
let err_val = af_set_default_random_engine_type(rtype as uint8_t);
214+
HANDLE_ERROR(AfError::from(err_val));
215+
}
216+
}
217+
218+
/// Generate array of uniform numbers using a random engine
219+
///
220+
/// # Parameters
221+
///
222+
/// - `dims` is output array dimensions
223+
/// - `engine` is an object of type [RandomEngine](./struct.RandomEngine.html)
224+
///
225+
/// # Return Values
226+
///
227+
/// An Array with uniform numbers generated using random engine
228+
pub fn random_uniform<T: HasAfEnum>(dims: Dim4, engine: RandomEngine) -> Array {
229+
unsafe {
230+
let aftype = T::get_af_dtype();
231+
let mut temp : i64 = 0;
232+
let err_val = af_random_uniform(&mut temp as MutAfArray,
233+
dims.ndims() as c_uint, dims.get().as_ptr() as *const DimT,
234+
aftype as uint8_t, engine.get() as RandEngine);
235+
HANDLE_ERROR(AfError::from(err_val));
236+
Array::from(temp)
237+
}
238+
}
239+
240+
/// Generate array of normal numbers using a random engine
241+
///
242+
/// # Parameters
243+
///
244+
/// - `dims` is output array dimensions
245+
/// - `engine` is an object of type [RandomEngine](./struct.RandomEngine.html)
246+
///
247+
/// # Return Values
248+
///
249+
/// An Array with normal numbers generated using random engine
250+
pub fn random_normal<T: HasAfEnum>(dims: Dim4, engine: RandomEngine) -> Array {
251+
unsafe {
252+
let aftype = T::get_af_dtype();
253+
let mut temp : i64 = 0;
254+
let err_val = af_random_normal(&mut temp as MutAfArray,
255+
dims.ndims() as c_uint, dims.get().as_ptr() as *const DimT,
256+
aftype as uint8_t, engine.get() as RandEngine);
257+
HANDLE_ERROR(AfError::from(err_val));
258+
Array::from(temp)
259+
}
260+
}

0 commit comments

Comments
 (0)