1919from randomstate .prng .xorshift128 import xorshift128
2020from randomstate .prng .xoroshiro128plus import xoroshiro128plus
2121from randomstate .prng .dsfmt import dsfmt
22- from numpy .testing import assert_almost_equal , assert_equal , assert_raises , assert_
22+ from numpy .testing import assert_almost_equal , assert_equal , assert_raises , assert_ , assert_array_equal
2323
2424from nose import SkipTest
2525
@@ -84,6 +84,19 @@ def comp_state(state1, state2):
8484 return identical
8585
8686
87+ def warmup (rs , n = None ):
88+ if n is None :
89+ n = 11 + np .random .randint (0 , 20 )
90+ rs .standard_normal (n , method = 'bm' )
91+ rs .standard_normal (n , method = 'zig' )
92+ rs .standard_normal (n , method = 'bm' , dtype = np .float32 )
93+ rs .randint (0 , 2 ** 24 , n , dtype = np .uint64 )
94+ rs .randint (0 , 2 ** 48 , n , dtype = np .uint64 )
95+ rs .standard_gamma (11 , n )
96+ rs .random_sample (n , dtype = np .float64 )
97+ rs .random_sample (n , dtype = np .float32 )
98+
99+
87100class RNG (object ):
88101 @classmethod
89102 def _extra_setup (cls ):
@@ -121,7 +134,7 @@ def test_jump(self):
121134
122135 def test_random_raw (self ):
123136 assert_ (len (self .rs .random_raw (10 )) == 10 )
124- assert_ (self .rs .random_raw ((10 ,10 )).shape == (10 ,10 ))
137+ assert_ (self .rs .random_raw ((10 , 10 )).shape == (10 , 10 ))
125138
126139 def test_uniform (self ):
127140 r = self .rs .uniform (- 1.0 , 0.0 , size = 10 )
@@ -202,17 +215,17 @@ def test_reset_state_gauss(self):
202215 rs2 = self .mod .RandomState ()
203216 rs2 .set_state (state )
204217 n2 = rs2 .standard_normal (size = 10 )
205- assert_ (( n1 == n2 ). all () )
218+ assert_array_equal ( n1 , n2 )
206219
207220 def test_reset_state_uint32 (self ):
208221 rs = self .mod .RandomState (* self .seed )
209- rs .randint (0 , 2 ** 24 , dtype = np .uint32 )
222+ rs .randint (0 , 2 ** 24 , 120 , dtype = np .uint32 )
210223 state = rs .get_state ()
211- n1 = rs .randint (0 , 2 ** 24 , 10 , dtype = np .uint32 )
224+ n1 = rs .randint (0 , 2 ** 24 , 10 , dtype = np .uint32 )
212225 rs2 = self .mod .RandomState ()
213226 rs2 .set_state (state )
214- n2 = rs .randint (0 , 2 ** 24 , 10 , dtype = np .uint32 )
215- assert_ (( n1 == n2 ). all () )
227+ n2 = rs2 .randint (0 , 2 ** 24 , 10 , dtype = np .uint32 )
228+ assert_array_equal ( n1 , n2 )
216229
217230 def test_shuffle (self ):
218231 original = np .arange (200 , 0 , - 1 )
@@ -487,10 +500,10 @@ def test_seed_array(self):
487500 def test_seed_array_error (self ):
488501 if self .seed_vector_bits == 32 :
489502 dtype = np .uint32
490- out_of_bounds = 2 ** 32
503+ out_of_bounds = 2 ** 32
491504 else :
492505 dtype = np .uint64
493- out_of_bounds = 2 ** 64
506+ out_of_bounds = 2 ** 64
494507
495508 seed = - 1
496509 assert_raises (ValueError , self .rs .seed , seed )
@@ -504,6 +517,32 @@ def test_seed_array_error(self):
504517 seed = np .array ([1 , 2 , 3 , out_of_bounds ])
505518 assert_raises (ValueError , self .rs .seed , seed )
506519
520+ def test_uniform_float (self ):
521+ rs = self .mod .RandomState (12345 )
522+ warmup (rs )
523+ state = rs .get_state ()
524+ r1 = rs .random_sample (11 , dtype = np .float32 )
525+ rs2 = self .mod .RandomState ()
526+ warmup (rs2 )
527+ rs2 .set_state (state )
528+ r2 = rs2 .random_sample (11 , dtype = np .float32 )
529+ assert_array_equal (r1 , r2 )
530+ assert_equal (r1 .dtype , np .float32 )
531+ assert_ (comp_state (rs .get_state (), rs2 .get_state ()))
532+
533+ def test_normal_floats (self ):
534+ rs = self .mod .RandomState ()
535+ warmup (rs )
536+ state = rs .get_state ()
537+ r1 = rs .standard_normal (11 , method = 'bm' , dtype = np .float32 )
538+ rs2 = self .mod .RandomState ()
539+ warmup (rs2 )
540+ rs2 .set_state (state )
541+ r2 = rs2 .standard_normal (11 , method = 'bm' , dtype = np .float32 )
542+ assert_array_equal (r1 , r2 )
543+ assert_equal (r1 .dtype , np .float32 )
544+ assert_ (comp_state (rs .get_state (), rs2 .get_state ()))
545+
507546
508547class TestMT19937 (RNG ):
509548 @classmethod
@@ -642,9 +681,3 @@ def test_fallback(self):
642681 time .sleep (0.1 )
643682 e2 = entropy .random_entropy (source = 'fallback' )
644683 assert_ ((e1 != e2 ))
645-
646-
647- if __name__ == '__main__' :
648- import nose
649-
650- nose .run (argv = [__file__ , '-vv' ])
0 commit comments