2626from pytensor .link .jax .dispatch .random import numpyro_available # noqa: E402
2727
2828
29- def compile_random_function (* args , mode = "JAX" , ** kwargs ):
29+ def compile_random_function (* args , mode = jax_mode , ** kwargs ):
3030 with pytest .warns (
3131 UserWarning , match = r"The RandomType SharedVariables \[.+\] will not be used"
3232 ):
@@ -41,7 +41,7 @@ def test_random_RandomStream():
4141 srng = RandomStream (seed = 123 )
4242 out = srng .normal () - srng .normal ()
4343
44- fn = compile_random_function ([], out , mode = jax_mode )
44+ fn = compile_random_function ([], out )
4545 jax_res_1 = fn ()
4646 jax_res_2 = fn ()
4747
@@ -54,7 +54,7 @@ def test_random_updates(rng_ctor):
5454 rng = shared (original_value , name = "original_rng" , borrow = False )
5555 next_rng , x = pt .random .normal (name = "x" , rng = rng ).owner .outputs
5656
57- f = compile_random_function ([], [x ], updates = {rng : next_rng }, mode = jax_mode )
57+ f = compile_random_function ([], [x ], updates = {rng : next_rng })
5858 assert f () != f ()
5959
6060 # Check that original rng variable content was not overwritten when calling jax_typify
@@ -482,7 +482,7 @@ def test_random_RandomVariable(rv_op, dist_params, base_size, cdf_name, params_c
482482 )
483483 rng = shared (np .random .default_rng (29403 ))
484484 g = rv_op (* dist_params , size = (10000 , * base_size ), rng = rng )
485- g_fn = compile_random_function (dist_params , g , mode = jax_mode )
485+ g_fn = compile_random_function (dist_params , g )
486486 samples = g_fn (* test_values )
487487
488488 bcast_dist_args = np .broadcast_arrays (* test_values )
@@ -518,7 +518,7 @@ def test_size_implied_by_broadcasted_parameters(rv_fn):
518518 param_that_implies_size = pt .matrix ("param_that_implies_size" , shape = (None , None ))
519519
520520 rv = rv_fn (param_that_implies_size )
521- draws = rv .eval ({param_that_implies_size : np .zeros ((2 , 2 ))}, mode = jax_mode )
521+ draws = rv .eval ({param_that_implies_size : np .zeros ((2 , 2 ))})
522522
523523 assert draws .shape == (2 , 2 )
524524 assert np .unique (draws ).size == 4
@@ -528,7 +528,7 @@ def test_size_implied_by_broadcasted_parameters(rv_fn):
528528def test_random_bernoulli (size ):
529529 rng = shared (np .random .default_rng (123 ))
530530 g = pt .random .bernoulli (0.5 , size = (1000 , * size ), rng = rng )
531- g_fn = compile_random_function ([], g , mode = jax_mode )
531+ g_fn = compile_random_function ([], g )
532532 samples = g_fn ()
533533 np .testing .assert_allclose (samples .mean (axis = 0 ), 0.5 , 1 )
534534
@@ -539,7 +539,7 @@ def test_random_mvnormal():
539539 mu = np .ones (4 )
540540 cov = np .eye (4 )
541541 g = pt .random .multivariate_normal (mu , cov , size = (10000 ,), rng = rng )
542- g_fn = compile_random_function ([], g , mode = jax_mode )
542+ g_fn = compile_random_function ([], g )
543543 samples = g_fn ()
544544 np .testing .assert_allclose (samples .mean (axis = 0 ), mu , atol = 0.1 )
545545
@@ -559,7 +559,7 @@ def test_random_mvnormal():
559559def test_random_dirichlet (parameter , size ):
560560 rng = shared (np .random .default_rng (123 ))
561561 g = pt .random .dirichlet (parameter , size = (1000 , * size ), rng = rng )
562- g_fn = compile_random_function ([], g , mode = jax_mode )
562+ g_fn = compile_random_function ([], g )
563563 samples = g_fn ()
564564 np .testing .assert_allclose (samples .mean (axis = 0 ), 0.5 , 1 )
565565
@@ -568,7 +568,7 @@ def test_random_choice():
568568 # `replace=True` and `p is None`
569569 rng = shared (np .random .default_rng (123 ))
570570 g = pt .random .choice (np .arange (4 ), size = 10_000 , rng = rng )
571- g_fn = compile_random_function ([], g , mode = jax_mode )
571+ g_fn = compile_random_function ([], g )
572572 samples = g_fn ()
573573 assert samples .shape == (10_000 ,)
574574 # Elements are picked at equal frequency
@@ -577,7 +577,7 @@ def test_random_choice():
577577 # `replace=True` and `p is not None`
578578 rng = shared (np .random .default_rng (123 ))
579579 g = pt .random .choice (4 , p = np .array ([0.0 , 0.5 , 0.0 , 0.5 ]), size = (5 , 2 ), rng = rng )
580- g_fn = compile_random_function ([], g , mode = jax_mode )
580+ g_fn = compile_random_function ([], g )
581581 samples = g_fn ()
582582 assert samples .shape == (5 , 2 )
583583 # Only odd numbers are picked
@@ -586,7 +586,7 @@ def test_random_choice():
586586 # `replace=False` and `p is None`
587587 rng = shared (np .random .default_rng (123 ))
588588 g = pt .random .choice (np .arange (100 ), replace = False , size = (2 , 49 ), rng = rng )
589- g_fn = compile_random_function ([], g , mode = jax_mode )
589+ g_fn = compile_random_function ([], g )
590590 samples = g_fn ()
591591 assert samples .shape == (2 , 49 )
592592 # Elements are unique
@@ -601,7 +601,7 @@ def test_random_choice():
601601 rng = rng ,
602602 replace = False ,
603603 )
604- g_fn = compile_random_function ([], g , mode = jax_mode )
604+ g_fn = compile_random_function ([], g )
605605 samples = g_fn ()
606606 assert samples .shape == (3 ,)
607607 # Elements are unique
@@ -613,14 +613,14 @@ def test_random_choice():
613613def test_random_categorical ():
614614 rng = shared (np .random .default_rng (123 ))
615615 g = pt .random .categorical (0.25 * np .ones (4 ), size = (10000 , 4 ), rng = rng )
616- g_fn = compile_random_function ([], g , mode = jax_mode )
616+ g_fn = compile_random_function ([], g )
617617 samples = g_fn ()
618618 assert samples .shape == (10000 , 4 )
619619 np .testing .assert_allclose (samples .mean (axis = 0 ), 6 / 4 , 1 )
620620
621621 # Test zero probabilities
622622 g = pt .random .categorical ([0 , 0.5 , 0 , 0.5 ], size = (1000 ,), rng = rng )
623- g_fn = compile_random_function ([], g , mode = jax_mode )
623+ g_fn = compile_random_function ([], g )
624624 samples = g_fn ()
625625 assert samples .shape == (1000 ,)
626626 assert np .all (samples % 2 == 1 )
@@ -630,7 +630,7 @@ def test_random_permutation():
630630 array = np .arange (4 )
631631 rng = shared (np .random .default_rng (123 ))
632632 g = pt .random .permutation (array , rng = rng )
633- g_fn = compile_random_function ([], g , mode = jax_mode )
633+ g_fn = compile_random_function ([], g )
634634 permuted = g_fn ()
635635 with pytest .raises (AssertionError ):
636636 np .testing .assert_allclose (array , permuted )
@@ -653,7 +653,7 @@ def test_random_geometric():
653653 rng = shared (np .random .default_rng (123 ))
654654 p = np .array ([0.3 , 0.7 ])
655655 g = pt .random .geometric (p , size = (10_000 , 2 ), rng = rng )
656- g_fn = compile_random_function ([], g , mode = jax_mode )
656+ g_fn = compile_random_function ([], g )
657657 samples = g_fn ()
658658 np .testing .assert_allclose (samples .mean (axis = 0 ), 1 / p , rtol = 0.1 )
659659 np .testing .assert_allclose (samples .std (axis = 0 ), np .sqrt ((1 - p ) / p ** 2 ), rtol = 0.1 )
@@ -664,7 +664,7 @@ def test_negative_binomial():
664664 n = np .array ([10 , 40 ])
665665 p = np .array ([0.3 , 0.7 ])
666666 g = pt .random .negative_binomial (n , p , size = (10_000 , 2 ), rng = rng )
667- g_fn = compile_random_function ([], g , mode = jax_mode )
667+ g_fn = compile_random_function ([], g )
668668 samples = g_fn ()
669669 np .testing .assert_allclose (samples .mean (axis = 0 ), n * (1 - p ) / p , rtol = 0.1 )
670670 np .testing .assert_allclose (
@@ -678,7 +678,7 @@ def test_binomial():
678678 n = np .array ([10 , 40 ])
679679 p = np .array ([0.3 , 0.7 ])
680680 g = pt .random .binomial (n , p , size = (10_000 , 2 ), rng = rng )
681- g_fn = compile_random_function ([], g , mode = jax_mode )
681+ g_fn = compile_random_function ([], g )
682682 samples = g_fn ()
683683 np .testing .assert_allclose (samples .mean (axis = 0 ), n * p , rtol = 0.1 )
684684 np .testing .assert_allclose (samples .std (axis = 0 ), np .sqrt (n * p * (1 - p )), rtol = 0.1 )
@@ -693,7 +693,7 @@ def test_beta_binomial():
693693 a = np .array ([1.5 , 13 ])
694694 b = np .array ([0.5 , 9 ])
695695 g = pt .random .betabinom (n , a , b , size = (10_000 , 2 ), rng = rng )
696- g_fn = compile_random_function ([], g , mode = jax_mode )
696+ g_fn = compile_random_function ([], g )
697697 samples = g_fn ()
698698 np .testing .assert_allclose (samples .mean (axis = 0 ), n * a / (a + b ), rtol = 0.1 )
699699 np .testing .assert_allclose (
@@ -754,7 +754,7 @@ def test_vonmises_mu_outside_circle():
754754 mu = np .array ([- 30 , 40 ])
755755 kappa = np .array ([100 , 10 ])
756756 g = pt .random .vonmises (mu , kappa , size = (10_000 , 2 ), rng = rng )
757- g_fn = compile_random_function ([], g , mode = jax_mode )
757+ g_fn = compile_random_function ([], g )
758758 samples = g_fn ()
759759 np .testing .assert_allclose (
760760 samples .mean (axis = 0 ), (mu + np .pi ) % (2.0 * np .pi ) - np .pi , rtol = 0.1
@@ -850,15 +850,15 @@ def test_random_concrete_shape():
850850 rng = shared (np .random .default_rng (123 ))
851851 x_pt = pt .dmatrix ()
852852 out = pt .random .normal (0 , 1 , size = x_pt .shape , rng = rng )
853- jax_fn = compile_random_function ([x_pt ], out , mode = jax_mode )
853+ jax_fn = compile_random_function ([x_pt ], out )
854854 assert jax_fn (np .ones ((2 , 3 ))).shape == (2 , 3 )
855855
856856
857857def test_random_concrete_shape_from_param ():
858858 rng = shared (np .random .default_rng (123 ))
859859 x_pt = pt .dmatrix ()
860860 out = pt .random .normal (x_pt , 1 , rng = rng )
861- jax_fn = compile_random_function ([x_pt ], out , mode = jax_mode )
861+ jax_fn = compile_random_function ([x_pt ], out )
862862 assert jax_fn (np .ones ((2 , 3 ))).shape == (2 , 3 )
863863
864864
@@ -877,7 +877,7 @@ def test_random_concrete_shape_subtensor():
877877 rng = shared (np .random .default_rng (123 ))
878878 x_pt = pt .dmatrix ()
879879 out = pt .random .normal (0 , 1 , size = x_pt .shape [1 ], rng = rng )
880- jax_fn = compile_random_function ([x_pt ], out , mode = jax_mode )
880+ jax_fn = compile_random_function ([x_pt ], out )
881881 assert jax_fn (np .ones ((2 , 3 ))).shape == (3 ,)
882882
883883
@@ -893,7 +893,7 @@ def test_random_concrete_shape_subtensor_tuple():
893893 rng = shared (np .random .default_rng (123 ))
894894 x_pt = pt .dmatrix ()
895895 out = pt .random .normal (0 , 1 , size = (x_pt .shape [0 ],), rng = rng )
896- jax_fn = compile_random_function ([x_pt ], out , mode = jax_mode )
896+ jax_fn = compile_random_function ([x_pt ], out )
897897 assert jax_fn (np .ones ((2 , 3 ))).shape == (2 ,)
898898
899899
@@ -904,7 +904,7 @@ def test_random_concrete_shape_graph_input():
904904 rng = shared (np .random .default_rng (123 ))
905905 size_pt = pt .scalar ()
906906 out = pt .random .normal (0 , 1 , size = size_pt , rng = rng )
907- jax_fn = compile_random_function ([size_pt ], out , mode = jax_mode )
907+ jax_fn = compile_random_function ([size_pt ], out )
908908 assert jax_fn (10 ).shape == (10 ,)
909909
910910
0 commit comments