|
9 | 9 | %include "import_flc.i" |
10 | 10 | %flc_add_header |
11 | 11 |
|
12 | | -/* ------------------------------------------------------------------------- |
13 | | - * Generator class definition |
14 | | - * ------------------------------------------------------------------------- */ |
15 | | - |
16 | 12 | %{ |
17 | 13 | #include <random> |
18 | 14 | %} |
19 | 15 |
|
20 | | -// TODO: define a CMake-configurable option for selecting the 32-bit twister |
21 | | -#if 0 |
22 | | -#define SWIG_MERSENNE_TWISTER mt19937 |
23 | | -#define SWIG_MERSENNE_RESULT_TYPE int32_t |
24 | | -#else |
25 | | -#define SWIG_MERSENNE_TWISTER mt19937_64 |
26 | | -#define SWIG_MERSENNE_RESULT_TYPE int64_t |
27 | | -#endif |
28 | | - |
29 | | -%rename(Engine) std::SWIG_MERSENNE_TWISTER; |
30 | | -%fortran_autofree_rvalue(std::SWIG_MERSENNE_TWISTER); |
| 16 | +/* ------------------------------------------------------------------------- |
| 17 | + * Macros |
| 18 | + * ------------------------------------------------------------------------- */ |
31 | 19 |
|
| 20 | +%define %flc_random_engine(NAME, GENERATOR, RESULT_TYPE) |
| 21 | +%fortran_autofree_rvalue(std::GENERATOR); |
32 | 22 | namespace std { |
33 | | -class SWIG_MERSENNE_TWISTER |
| 23 | + |
| 24 | +%rename(NAME) GENERATOR; |
| 25 | +%rename("next") GENERATOR::operator(); |
| 26 | + |
| 27 | +class GENERATOR |
34 | 28 | { |
35 | 29 | public: |
36 | | - typedef SWIG_MERSENNE_RESULT_TYPE result_type; |
| 30 | + typedef RESULT_TYPE result_type; |
37 | 31 |
|
38 | | - SWIG_MERSENNE_TWISTER(); |
39 | | - explicit SWIG_MERSENNE_TWISTER(result_type seed_value); |
| 32 | + GENERATOR(); |
| 33 | + explicit GENERATOR(result_type seed_value); |
40 | 34 | void seed(result_type seed_value); |
41 | 35 | void discard(unsigned long long count); |
| 36 | + result_type operator()(); |
42 | 37 | }; |
| 38 | + |
43 | 39 | } // namespace std |
| 40 | +%enddef |
44 | 41 |
|
45 | 42 | /* ------------------------------------------------------------------------- |
46 | 43 | * RNG distribution routines |
47 | | - * |
48 | | - * The generated subroutines will be called from Fortran like: |
49 | | - * |
50 | | - * call uniform_real_distribution(gen, -10, 10, fill_array) |
51 | 44 | * ------------------------------------------------------------------------- */ |
52 | 45 |
|
53 | | -%define %flc_random_distribution1(DISTNAME, TYPE, ARG1) |
54 | | -%inline { |
55 | | -static void DISTNAME(TYPE ARG1, |
56 | | - std::SWIG_MERSENNE_TWISTER& g, |
57 | | - TYPE *DATA, size_t DATASIZE) { |
58 | | - std::DISTNAME<TYPE> dist(ARG1); |
59 | | - TYPE *end = DATA + DATASIZE; |
60 | | - while (DATA != end) { |
61 | | - *DATA++ = dist(g); |
62 | | - } |
| 46 | +%{ |
| 47 | +template<class D, class G, class T> |
| 48 | +static inline void flc_generate(D dist, G& g, T* data, size_t size) { |
| 49 | + T* const end = data + size; |
| 50 | + while (data != end) { |
| 51 | + *data++ = dist(g); |
| 52 | + } |
63 | 53 | } |
| 54 | +%} |
| 55 | + |
| 56 | +%apply (const SWIGTYPE *DATA, size_t SIZE) { |
| 57 | + (const int32_t *WEIGHTS, size_t WEIGHTSIZE), |
| 58 | + (const int64_t *WEIGHTS, size_t WEIGHTSIZE) }; |
| 59 | + |
| 60 | +%inline %{ |
| 61 | +template<class T, class G> |
| 62 | +static void uniform_int_distribution(T left, T right, |
| 63 | + G& engine, T* DATA, size_t DATASIZE) { |
| 64 | + flc_generate(std::uniform_int_distribution<T>(left, right), |
| 65 | + engine, DATA, DATASIZE); |
64 | 66 | } |
65 | | -%enddef |
66 | | -%define %flc_random_distribution2(DISTNAME, TYPE, ARG1, ARG2) |
67 | | -%inline { |
68 | | -static void DISTNAME(TYPE ARG1, TYPE ARG2, |
69 | | - std::SWIG_MERSENNE_TWISTER& g, |
70 | | - TYPE *DATA, size_t DATASIZE) { |
71 | | - std::DISTNAME<TYPE> dist(ARG1, ARG2); |
72 | | - TYPE *end = DATA + DATASIZE; |
73 | | - while (DATA != end) { |
74 | | - *DATA++ = dist(g); |
75 | | - } |
| 67 | + |
| 68 | +template<class T, class G> |
| 69 | +static void uniform_real_distribution(T left, T right, |
| 70 | + G& engine, T* DATA, size_t DATASIZE) { |
| 71 | + flc_generate(std::uniform_real_distribution<T>(left, right), |
| 72 | + engine, DATA, DATASIZE); |
| 73 | +} |
| 74 | + |
| 75 | +template<class T, class G> |
| 76 | +static void normal_distribution(T mean, T stddev, |
| 77 | + G& engine, T* DATA, size_t DATASIZE) { |
| 78 | + flc_generate(std::normal_distribution<T>(mean, stddev), |
| 79 | + engine, DATA, DATASIZE); |
76 | 80 | } |
| 81 | + |
| 82 | +template<class T, class G> |
| 83 | +static void discrete_distribution(const T* WEIGHTS, size_t WEIGHTSIZE, |
| 84 | + G& engine, T* DATA, size_t DATASIZE) { |
| 85 | + std::discrete_distribution<T> dist(WEIGHTS, WEIGHTS + WEIGHTSIZE); |
| 86 | + T* const end = DATA + DATASIZE; |
| 87 | + while (DATA != end) { |
| 88 | + *DATA++ = dist(engine) + 1; // Note: transform to Fortran 1-offset |
| 89 | + } |
77 | 90 | } |
| 91 | +%} |
| 92 | + |
| 93 | +%define %flc_distribution(NAME, STDENGINE, TYPE) |
| 94 | +%template(NAME##_distribution) NAME##_distribution< TYPE, std::STDENGINE >; |
78 | 95 | %enddef |
79 | 96 |
|
80 | | -// Uniform distributions |
81 | | -%flc_random_distribution2(uniform_int_distribution, int32_t, left, right) |
82 | | -%flc_random_distribution2(uniform_int_distribution, int64_t, left, right) |
83 | | -%flc_random_distribution2(uniform_real_distribution, double, left, right) |
| 97 | +// Engines |
| 98 | +%flc_random_engine(MersenneEngine4, mt19937, int32_t) |
| 99 | +%flc_random_engine(MersenneEngine8, mt19937_64, int64_t) |
| 100 | + |
| 101 | +#define FLC_DEFAULT_ENGINE mt19937 |
| 102 | +%flc_distribution(uniform_int, FLC_DEFAULT_ENGINE, int32_t) |
| 103 | +%flc_distribution(uniform_int, FLC_DEFAULT_ENGINE, int64_t) |
| 104 | +%flc_distribution(uniform_real, FLC_DEFAULT_ENGINE, double) |
| 105 | + |
| 106 | +%flc_distribution(normal, FLC_DEFAULT_ENGINE, double) |
84 | 107 |
|
85 | | -// Gaussian distribution |
86 | | -%flc_random_distribution1(normal_distribution, double, mean) |
87 | | -%flc_random_distribution2(normal_distribution, double, mean, stddev) |
| 108 | +// Discrete sampling distribution |
| 109 | +%flc_distribution(discrete, FLC_DEFAULT_ENGINE, int32_t) |
| 110 | +%flc_distribution(discrete, FLC_DEFAULT_ENGINE, int64_t) |
0 commit comments