@@ -4,12 +4,9 @@ DEF RS_RNG_JUMPABLE = 1
44cdef extern from " distributions.h" :
55
66 cdef struct s_mrg32k3a_state:
7- int64_t s10
8- int64_t s11
9- int64_t s12
10- int64_t s20
11- int64_t s21
12- int64_t s22
7+ int64_t s1[3 ]
8+ int64_t s2[3 ]
9+ int loc
1310
1411 ctypedef s_mrg32k3a_state mrg32k3a_state
1512
@@ -31,16 +28,18 @@ ctypedef mrg32k3a_state rng_t
3128ctypedef uint64_t rng_state_t
3229
3330cdef object _get_state(aug_state state):
34- return (state.rng.s10, state.rng.s11, state.rng.s12,
35- state.rng.s20, state.rng.s21, state.rng.s22)
31+ return (state.rng.s1[0 ], state.rng.s1[1 ], state.rng.s1[2 ],
32+ state.rng.s2[0 ], state.rng.s2[1 ], state.rng.s2[2 ],
33+ state.rng.loc)
3634
3735cdef object _set_state(aug_state * state, object state_info):
38- state.rng.s10 = state_info[0 ]
39- state.rng.s11 = state_info[1 ]
40- state.rng.s12 = state_info[2 ]
41- state.rng.s20 = state_info[3 ]
42- state.rng.s21 = state_info[4 ]
43- state.rng.s22 = state_info[5 ]
36+ state.rng.s1[0 ] = state_info[0 ]
37+ state.rng.s1[1 ] = state_info[1 ]
38+ state.rng.s1[2 ] = state_info[2 ]
39+ state.rng.s2[0 ] = state_info[3 ]
40+ state.rng.s2[1 ] = state_info[4 ]
41+ state.rng.s2[2 ] = state_info[5 ]
42+ state.rng.loc = state_info[6 ]
4443
4544cdef object matrix_power_127(x, m):
4645 n = x.shape[0 ]
@@ -68,21 +67,39 @@ A2_127 = matrix_power_127(A2p, m2)
6867
6968cdef void jump_state(aug_state* state):
7069 # vectors s1 and s2
71- s1 = np.array([state.rng.s10,state.rng.s11,state.rng.s12], dtype = np.uint64)
72- s2 = np.array([state.rng.s20,state.rng.s21,state.rng.s22], dtype = np.uint64)
70+ loc = state.rng.loc
71+
72+ if loc == 0 :
73+ loc_m1 = 2
74+ loc_m2 = 1
75+ elif loc == 1 :
76+ loc_m1 = 0
77+ loc_m2 = 2
78+ else :
79+ loc_m1 = 1
80+ loc_m2 = 0
81+
82+ s1 = np.array([state.rng.s1[loc_m2],
83+ state.rng.s1[loc_m1],
84+ state.rng.s1[loc]], dtype = np.uint64)
85+ s2 = np.array([state.rng.s2[loc_m2],
86+ state.rng.s2[loc_m1],
87+ state.rng.s2[loc]], dtype = np.uint64)
7388
7489 # Advance the state
7590 s1 = np.mod(A1_127.dot(s1), m1)
7691 s2 = np.mod(A1_127.dot(s2), m2)
7792
7893 # Restore state
79- state.rng.s10 = s1[0 ]
80- state.rng.s11 = s1[1 ]
81- state.rng.s12 = s1[2 ]
94+ state.rng.s1[ 0 ] = s1[0 ]
95+ state.rng.s1[ 1 ] = s1[1 ]
96+ state.rng.s1[ 2 ] = s1[2 ]
8297
83- state.rng.s20 = s2[0 ]
84- state.rng.s21 = s2[1 ]
85- state.rng.s22 = s2[2 ]
98+ state.rng.s2[0 ] = s2[0 ]
99+ state.rng.s2[1 ] = s2[1 ]
100+ state.rng.s2[2 ] = s2[2 ]
101+
102+ state.rng.loc = 2
86103
87104DEF CLASS_DOCSTRING = """
88105RandomState(seed=None)
0 commit comments