@@ -53,6 +53,36 @@ static inline double gauss(aug_state* state)
5353 }
5454}
5555
56+ static inline float gauss_float (aug_state * state )
57+ {
58+ if (state -> has_gauss_float )
59+ {
60+ const float temp = state -> gauss_float ;
61+ state -> has_gauss_float = false;
62+ state -> gauss_float = 0.0f ;
63+ return temp ;
64+ }
65+ else
66+ {
67+ float f , x1 , x2 , r2 ;
68+
69+ do {
70+ x1 = 2.0f * random_float (state ) - 1.0f ;
71+ x2 = 2.0f * random_float (state ) - 1.0f ;
72+ r2 = x1 * x1 + x2 * x2 ;
73+ }
74+ while (r2 >= 1.0 || r2 == 0.0 );
75+
76+ /* Box-Muller transform */
77+ f = sqrtf (-2.0f * logf (r2 )/r2 );
78+ /* Keep for next call */
79+ state -> gauss_float = f * x1 ;
80+ state -> has_gauss_float = true;
81+ return f * x2 ;
82+ }
83+ }
84+
85+
5686/*
5787* Julia implementation of Ziggurat algo
5888* MIT license
@@ -92,6 +122,116 @@ static inline double gauss_zig_julia(aug_state* state)
92122 }
93123}
94124
125+
126+ static inline double standard_gamma (aug_state * state , double shape )
127+ {
128+ double b , c ;
129+ double U , V , X , Y ;
130+
131+ if (shape == 1.0 )
132+ {
133+ return standard_exponential (state );
134+ }
135+ else if (shape < 1.0 )
136+ {
137+ for (;;)
138+ {
139+ U = random_double (state );
140+ V = standard_exponential (state );
141+ if (U <= 1.0 - shape )
142+ {
143+ X = pow (U , 1. /shape );
144+ if (X <= V )
145+ {
146+ return X ;
147+ }
148+ }
149+ else
150+ {
151+ Y = - log ((1 - U )/shape );
152+ X = pow (1.0 - shape + shape * Y , 1. /shape );
153+ if (X <= (V + Y ))
154+ {
155+ return X ;
156+ }
157+ }
158+ }
159+ }
160+ else
161+ {
162+ b = shape - 1. /3. ;
163+ c = 1. /sqrt (9 * b );
164+ for (;;)
165+ {
166+ do
167+ {
168+ X = gauss (state );
169+ V = 1.0 + c * X ;
170+ } while (V <= 0.0 );
171+
172+ V = V * V * V ;
173+ U = random_double (state );
174+ if (U < 1.0 - 0.0331 * (X * X )* (X * X )) return (b * V );
175+ if (log (U ) < 0.5 * X * X + b * (1. - V + log (V ))) return (b * V );
176+ }
177+ }
178+ }
179+
180+ static inline float standard_gamma_float (aug_state * state , float shape )
181+ {
182+ float b , c ;
183+ float U , V , X , Y ;
184+
185+ if (shape == 1.0f )
186+ {
187+ return standard_exponential_float (state );
188+ }
189+ else if (shape < 1.0f )
190+ {
191+ for (;;)
192+ {
193+ U = random_float (state );
194+ V = standard_exponential_float (state );
195+ if (U <= 1.0f - shape )
196+ {
197+ X = powf (U , 1.0f /shape );
198+ if (X <= V )
199+ {
200+ return X ;
201+ }
202+ }
203+ else
204+ {
205+ Y = - logf ((1.0f - U )/shape );
206+ X = powf (1.0f - shape + shape * Y , 1.0f /shape );
207+ if (X <= (V + Y ))
208+ {
209+ return X ;
210+ }
211+ }
212+ }
213+ }
214+ else
215+ {
216+ b = shape - 1.0f /3.0f ;
217+ c = 1.0f / sqrtf (9.0f * b );
218+ for (;;)
219+ {
220+ do
221+ {
222+ X = gauss_float (state );
223+ V = 1.0f + c * X ;
224+ } while (V <= 0.0f );
225+
226+ V = V * V * V ;
227+ U = random_float (state );
228+ if (U < 1.0f - 0.0331f * (X * X )* (X * X )) return (b * V );
229+ if (logf (U ) < 0.5f * X * X + b * (1.0f - V + logf (V ))) return (b * V );
230+ }
231+ }
232+ }
233+
234+
95235/*
96236 *
97237 * RNGs for use in other code
@@ -186,59 +326,17 @@ void random_gauss_fill(aug_state* state, npy_intp count, double *out) {
186326 }
187327}
188328
329+ void random_gauss_fill_float (aug_state * state , npy_intp count , float * out ) {
189330
190- double random_standard_gamma (aug_state * state , double shape )
191- {
192- double b , c ;
193- double U , V , X , Y ;
194-
195- if (shape == 1.0 )
196- {
197- return random_standard_exponential (state );
198- }
199- else if (shape < 1.0 )
200- {
201- for (;;)
202- {
203- U = random_double (state );
204- V = random_standard_exponential (state );
205- if (U <= 1.0 - shape )
206- {
207- X = pow (U , 1. /shape );
208- if (X <= V )
209- {
210- return X ;
211- }
212- }
213- else
214- {
215- Y = - log ((1 - U )/shape );
216- X = pow (1.0 - shape + shape * Y , 1. /shape );
217- if (X <= (V + Y ))
218- {
219- return X ;
220- }
221- }
222- }
331+ npy_intp i ;
332+ for (i = 0 ; i < count ; i ++ ) {
333+ out [i ] = gauss_float (state );
223334 }
224- else
225- {
226- b = shape - 1. /3. ;
227- c = 1. /sqrt (9 * b );
228- for (;;)
229- {
230- do
231- {
232- X = random_gauss (state );
233- V = 1.0 + c * X ;
234- } while (V <= 0.0 );
335+ }
235336
236- V = V * V * V ;
237- U = random_double (state );
238- if (U < 1.0 - 0.0331 * (X * X )* (X * X )) return (b * V );
239- if (log (U ) < 0.5 * X * X + b * (1. - V + log (V ))) return (b * V );
240- }
241- }
337+ double random_standard_gamma (aug_state * state , double shape )
338+ {
339+ return standard_gamma (state , shape );
242340}
243341
244342
@@ -311,7 +409,12 @@ double random_uniform(aug_state *state, double lower, double range)
311409
312410double random_gamma (aug_state * state , double shape , double scale )
313411{
314- return scale * random_standard_gamma (state , shape );
412+ return scale * standard_gamma (state , shape );
413+ }
414+
415+ float random_gamma_float (aug_state * state , float shape , float scale )
416+ {
417+ return scale * standard_gamma_float (state , shape );
315418}
316419
317420double random_beta (aug_state * state , double a , double b )
0 commit comments