|
33 | 33 |
|
34 | 34 | from pymc.blocking import DictToArrayBijection |
35 | 35 | from pymc.exceptions import SamplingError |
36 | | -from pymc.util import CustomProgress, default_progress_theme |
| 36 | +from pymc.util import ( |
| 37 | + CustomProgress, |
| 38 | + RandomGeneratorState, |
| 39 | + default_progress_theme, |
| 40 | + get_state_from_generator, |
| 41 | + random_generator_from_state, |
| 42 | +) |
37 | 43 |
|
38 | 44 | logger = logging.getLogger(__name__) |
39 | 45 |
|
@@ -96,13 +102,12 @@ def __init__( |
96 | 102 | shared_point, |
97 | 103 | draws: int, |
98 | 104 | tune: int, |
99 | | - rng: np.random.Generator, |
100 | | - seed_seq: np.random.SeedSequence, |
| 105 | + rng_state: RandomGeneratorState, |
101 | 106 | blas_cores, |
102 | 107 | ): |
103 | 108 | # For some strange reason, spawn multiprocessing doesn't copy the rng |
104 | 109 | # seed sequence, so we have to rebuild it from scratch |
105 | | - rng = np.random.Generator(type(rng.bit_generator)(seed_seq)) |
| 110 | + rng = random_generator_from_state(rng_state) |
106 | 111 | self._msg_pipe = msg_pipe |
107 | 112 | self._step_method = step_method |
108 | 113 | self._step_method_is_pickled = step_method_is_pickled |
@@ -263,8 +268,7 @@ def __init__( |
263 | 268 | self._shared_point, |
264 | 269 | draws, |
265 | 270 | tune, |
266 | | - rng, |
267 | | - rng.bit_generator.seed_seq, |
| 271 | + get_state_from_generator(rng), |
268 | 272 | blas_cores, |
269 | 273 | ), |
270 | 274 | ) |
|
0 commit comments