@@ -105,14 +105,24 @@ def jax_funcify_RandomVariable(op: ptr.RandomVariable, node, **kwargs):
105105 assert_size_argument_jax_compatible (node )
106106
107107 def sample_fn (rng , size , * parameters ):
108- return jax_sample_fn (op , node = node )(rng , size , out_dtype , * parameters )
108+ rng_key = rng ["jax_state" ]
109+ rng_key , sampling_key = jax .random .split (rng_key , 2 )
110+ rng ["jax_state" ] = rng_key
111+ sample = jax_sample_fn (op , node = node )(
112+ sampling_key , size , out_dtype , * parameters
113+ )
114+ return (rng , sample )
109115
110116 else :
111117
112118 def sample_fn (rng , size , * parameters ):
113- return jax_sample_fn (op , node = node )(
114- rng , static_size , out_dtype , * parameters
119+ rng_key = rng ["jax_state" ]
120+ rng_key , sampling_key = jax .random .split (rng_key , 2 )
121+ rng ["jax_state" ] = rng_key
122+ sample = jax_sample_fn (op , node = node )(
123+ sampling_key , static_size , out_dtype , * parameters
115124 )
125+ return (rng , sample )
116126
117127 return sample_fn
118128
@@ -133,12 +143,9 @@ def jax_sample_fn_generic(op, node):
133143 name = op .name
134144 jax_op = getattr (jax .random , name )
135145
136- def sample_fn (rng , size , dtype , * parameters ):
137- rng_key = rng ["jax_state" ]
138- rng_key , sampling_key = jax .random .split (rng_key , 2 )
139- sample = jax_op (sampling_key , * parameters , shape = size , dtype = dtype )
140- rng ["jax_state" ] = rng_key
141- return (rng , sample )
146+ def sample_fn (rng_key , size , dtype , * parameters ):
147+ sample = jax_op (rng_key , * parameters , shape = size , dtype = dtype )
148+ return sample
142149
143150 return sample_fn
144151
@@ -159,29 +166,23 @@ def jax_sample_fn_loc_scale(op, node):
159166 name = op .name
160167 jax_op = getattr (jax .random , name )
161168
162- def sample_fn (rng , size , dtype , * parameters ):
163- rng_key = rng ["jax_state" ]
164- rng_key , sampling_key = jax .random .split (rng_key , 2 )
169+ def sample_fn (rng_key , size , dtype , * parameters ):
165170 loc , scale = parameters
166171 if size is None :
167172 size = jax .numpy .broadcast_arrays (loc , scale )[0 ].shape
168- sample = loc + jax_op (sampling_key , size , dtype ) * scale
169- rng ["jax_state" ] = rng_key
170- return (rng , sample )
173+ sample = loc + jax_op (rng_key , size , dtype ) * scale
174+ return sample
171175
172176 return sample_fn
173177
174178
175179@jax_sample_fn .register (ptr .MvNormalRV )
176180def jax_sample_mvnormal (op , node ):
177- def sample_fn (rng , size , dtype , mean , cov ):
178- rng_key = rng ["jax_state" ]
179- rng_key , sampling_key = jax .random .split (rng_key , 2 )
181+ def sample_fn (rng_key , size , dtype , mean , cov ):
180182 sample = jax .random .multivariate_normal (
181- sampling_key , mean , cov , shape = size , dtype = dtype , method = op .method
183+ rng_key , mean , cov , shape = size , dtype = dtype , method = op .method
182184 )
183- rng ["jax_state" ] = rng_key
184- return (rng , sample )
185+ return sample
185186
186187 return sample_fn
187188
@@ -191,12 +192,9 @@ def jax_sample_fn_bernoulli(op, node):
191192 """JAX implementation of `BernoulliRV`."""
192193
193194 # We need a separate dispatch, because there is no dtype argument for Bernoulli in JAX
194- def sample_fn (rng , size , dtype , p ):
195- rng_key = rng ["jax_state" ]
196- rng_key , sampling_key = jax .random .split (rng_key , 2 )
197- sample = jax .random .bernoulli (sampling_key , p , shape = size )
198- rng ["jax_state" ] = rng_key
199- return (rng , sample )
195+ def sample_fn (rng_key , size , dtype , p ):
196+ sample = jax .random .bernoulli (rng_key , p , shape = size )
197+ return sample
200198
201199 return sample_fn
202200
@@ -206,14 +204,10 @@ def jax_sample_fn_categorical(op, node):
206204 """JAX implementation of `CategoricalRV`."""
207205
208206 # We need a separate dispatch because Categorical expects logits in JAX
209- def sample_fn (rng , size , dtype , p ):
210- rng_key = rng ["jax_state" ]
211- rng_key , sampling_key = jax .random .split (rng_key , 2 )
212-
207+ def sample_fn (rng_key , size , dtype , p ):
213208 logits = jax .scipy .special .logit (p )
214- sample = jax .random .categorical (sampling_key , logits = logits , shape = size )
215- rng ["jax_state" ] = rng_key
216- return (rng , sample )
209+ sample = jax .random .categorical (rng_key , logits = logits , shape = size )
210+ return sample
217211
218212 return sample_fn
219213
@@ -233,15 +227,10 @@ def jax_sample_fn_uniform(op, node):
233227 name = "randint"
234228 jax_op = getattr (jax .random , name )
235229
236- def sample_fn (rng , size , dtype , * parameters ):
237- rng_key = rng ["jax_state" ]
238- rng_key , sampling_key = jax .random .split (rng_key , 2 )
230+ def sample_fn (rng_key , size , dtype , * parameters ):
239231 minval , maxval = parameters
240- sample = jax_op (
241- sampling_key , shape = size , dtype = dtype , minval = minval , maxval = maxval
242- )
243- rng ["jax_state" ] = rng_key
244- return (rng , sample )
232+ sample = jax_op (rng_key , shape = size , dtype = dtype , minval = minval , maxval = maxval )
233+ return sample
245234
246235 return sample_fn
247236
@@ -258,14 +247,11 @@ def jax_sample_fn_shape_scale(op, node):
258247 name = op .name
259248 jax_op = getattr (jax .random , name )
260249
261- def sample_fn (rng , size , dtype , shape , scale ):
262- rng_key = rng ["jax_state" ]
263- rng_key , sampling_key = jax .random .split (rng_key , 2 )
250+ def sample_fn (rng_key , size , dtype , shape , scale ):
264251 if size is None :
265252 size = jax .numpy .broadcast_arrays (shape , scale )[0 ].shape
266- sample = jax_op (sampling_key , shape , size , dtype ) * scale
267- rng ["jax_state" ] = rng_key
268- return (rng , sample )
253+ sample = jax_op (rng_key , shape , size , dtype ) * scale
254+ return sample
269255
270256 return sample_fn
271257
@@ -274,14 +260,11 @@ def sample_fn(rng, size, dtype, shape, scale):
274260def jax_sample_fn_exponential (op , node ):
275261 """JAX implementation of `ExponentialRV`."""
276262
277- def sample_fn (rng , size , dtype , scale ):
278- rng_key = rng ["jax_state" ]
279- rng_key , sampling_key = jax .random .split (rng_key , 2 )
263+ def sample_fn (rng_key , size , dtype , scale ):
280264 if size is None :
281265 size = jax .numpy .asarray (scale ).shape
282- sample = jax .random .exponential (sampling_key , size , dtype ) * scale
283- rng ["jax_state" ] = rng_key
284- return (rng , sample )
266+ sample = jax .random .exponential (rng_key , size , dtype ) * scale
267+ return sample
285268
286269 return sample_fn
287270
@@ -290,14 +273,11 @@ def sample_fn(rng, size, dtype, scale):
290273def jax_sample_fn_t (op , node ):
291274 """JAX implementation of `StudentTRV`."""
292275
293- def sample_fn (rng , size , dtype , df , loc , scale ):
294- rng_key = rng ["jax_state" ]
295- rng_key , sampling_key = jax .random .split (rng_key , 2 )
276+ def sample_fn (rng_key , size , dtype , df , loc , scale ):
296277 if size is None :
297278 size = jax .numpy .broadcast_arrays (df , loc , scale )[0 ].shape
298- sample = loc + jax .random .t (sampling_key , df , size , dtype ) * scale
299- rng ["jax_state" ] = rng_key
300- return (rng , sample )
279+ sample = loc + jax .random .t (rng_key , df , size , dtype ) * scale
280+ return sample
301281
302282 return sample_fn
303283
@@ -315,10 +295,7 @@ def jax_funcify_choice(op: ptr.ChoiceWithoutReplacement, node):
315295 "A default JAX rewrite should have materialized the implicit arange"
316296 )
317297
318- def sample_fn (rng , size , dtype , * parameters ):
319- rng_key = rng ["jax_state" ]
320- rng_key , sampling_key = jax .random .split (rng_key , 2 )
321-
298+ def sample_fn (rng_key , size , dtype , * parameters ):
322299 if op .has_p_param :
323300 a , p , core_shape = parameters
324301 else :
@@ -327,9 +304,7 @@ def sample_fn(rng, size, dtype, *parameters):
327304 core_shape = tuple (np .asarray (core_shape )[(0 ,) * batch_ndim ])
328305
329306 if batch_ndim == 0 :
330- sample = jax .random .choice (
331- sampling_key , a , shape = core_shape , replace = False , p = p
332- )
307+ sample = jax .random .choice (rng_key , a , shape = core_shape , replace = False , p = p )
333308
334309 else :
335310 if size is None :
@@ -345,7 +320,7 @@ def sample_fn(rng, size, dtype, *parameters):
345320 if p is not None :
346321 p = jax .numpy .broadcast_to (p , size + p .shape [batch_ndim :])
347322
348- batch_sampling_keys = jax .random .split (sampling_key , np .prod (size ))
323+ batch_sampling_keys = jax .random .split (rng_key , np .prod (size ))
349324
350325 # Ravel the batch dimensions because vmap only works along a single axis
351326 raveled_batch_a = a .reshape ((- 1 ,) + a .shape [batch_ndim :])
@@ -366,8 +341,7 @@ def sample_fn(rng, size, dtype, *parameters):
366341 # Reshape the batch dimensions
367342 sample = raveled_sample .reshape (size + raveled_sample .shape [1 :])
368343
369- rng ["jax_state" ] = rng_key
370- return (rng , sample )
344+ return sample
371345
372346 return sample_fn
373347
@@ -378,9 +352,7 @@ def jax_sample_fn_permutation(op, node):
378352
379353 batch_ndim = op .batch_ndim (node )
380354
381- def sample_fn (rng , size , dtype , * parameters ):
382- rng_key = rng ["jax_state" ]
383- rng_key , sampling_key = jax .random .split (rng_key , 2 )
355+ def sample_fn (rng_key , size , dtype , * parameters ):
384356 (x ,) = parameters
385357 if batch_ndim :
386358 # jax.random.permutation has no concept of batch dims
@@ -389,17 +361,16 @@ def sample_fn(rng, size, dtype, *parameters):
389361 else :
390362 x = jax .numpy .broadcast_to (x , size + x .shape [batch_ndim :])
391363
392- batch_sampling_keys = jax .random .split (sampling_key , np .prod (size ))
364+ batch_sampling_keys = jax .random .split (rng_key , np .prod (size ))
393365 raveled_batch_x = x .reshape ((- 1 ,) + x .shape [batch_ndim :])
394366 raveled_sample = jax .vmap (lambda key , x : jax .random .permutation (key , x ))(
395367 batch_sampling_keys , raveled_batch_x
396368 )
397369 sample = raveled_sample .reshape (size + raveled_sample .shape [1 :])
398370 else :
399- sample = jax .random .permutation (sampling_key , x )
371+ sample = jax .random .permutation (rng_key , x )
400372
401- rng ["jax_state" ] = rng_key
402- return (rng , sample )
373+ return sample
403374
404375 return sample_fn
405376
@@ -414,15 +385,9 @@ def jax_sample_fn_binomial(op, node):
414385
415386 from numpyro .distributions .util import binomial
416387
417- def sample_fn (rng , size , dtype , n , p ):
418- rng_key = rng ["jax_state" ]
419- rng_key , sampling_key = jax .random .split (rng_key , 2 )
420-
421- sample = binomial (key = sampling_key , n = n , p = p , shape = size )
422-
423- rng ["jax_state" ] = rng_key
424-
425- return (rng , sample )
388+ def sample_fn (rng_key , size , dtype , n , p ):
389+ sample = binomial (key = rng_key , n = n , p = p , shape = size )
390+ return sample
426391
427392 return sample_fn
428393
@@ -437,15 +402,9 @@ def jax_sample_fn_multinomial(op, node):
437402
438403 from numpyro .distributions .util import multinomial
439404
440- def sample_fn (rng , size , dtype , n , p ):
441- rng_key = rng ["jax_state" ]
442- rng_key , sampling_key = jax .random .split (rng_key , 2 )
443-
444- sample = multinomial (key = sampling_key , n = n , p = p , shape = size )
445-
446- rng ["jax_state" ] = rng_key
447-
448- return (rng , sample )
405+ def sample_fn (rng_key , size , dtype , n , p ):
406+ sample = multinomial (key = rng_key , n = n , p = p , shape = size )
407+ return sample
449408
450409 return sample_fn
451410
@@ -460,17 +419,12 @@ def jax_sample_fn_vonmises(op, node):
460419
461420 from numpyro .distributions .util import von_mises_centered
462421
463- def sample_fn (rng , size , dtype , mu , kappa ):
464- rng_key = rng ["jax_state" ]
465- rng_key , sampling_key = jax .random .split (rng_key , 2 )
466-
422+ def sample_fn (rng_key , size , dtype , mu , kappa ):
467423 sample = von_mises_centered (
468- key = sampling_key , concentration = kappa , shape = size , dtype = dtype
424+ key = rng_key , concentration = kappa , shape = size , dtype = dtype
469425 )
470426 sample = (sample + mu + np .pi ) % (2.0 * np .pi ) - np .pi
471427
472- rng ["jax_state" ] = rng_key
473-
474- return (rng , sample )
428+ return sample
475429
476430 return sample_fn
0 commit comments