@@ -63,6 +63,43 @@ def test_random_updates(rng_ctor):
6363 )
6464
6565
def test_random_updates_input_storage_order():
    """Regression test for issue #314.

    Updating the input storage after cloning the shared RNG used to call
    ``input_storage.index(old_input_storage)``. When numpy arrays appeared in
    ``input_storage`` before the RNG value, that lookup hit the array equality
    check and failed.
    """
    rng = RandomStream(1)

    batch_shape = (3, 1, 4, 4)
    shared_input = pytensor.shared(
        np.zeros(batch_shape, dtype="float64"), name="inp_shared"
    )

    symbolic_input = at.tensor4(dtype="float64", name="inp")
    update_expr = symbolic_input + rng.normal(
        size=symbolic_input.shape, loc=5, scale=1e-5
    )

    # Substituting shared_input for symbolic_input via `givens` is what made
    # the RNG land after the shared input in the input_storage, triggering the
    # original failure.
    with pytest.warns(
        UserWarning,
        match=r"The RandomType SharedVariables \[.+\] will not be used",
    ):
        fn = pytensor.function(
            inputs=[],
            outputs=[],
            updates={shared_input: update_expr},
            givens={symbolic_input: shared_input},
            mode="JAX",
        )

    # Each call adds ~5 (noise scale 1e-5) to every element of the shared value.
    fn()
    np.testing.assert_allclose(shared_input.get_value(), 5, rtol=1e-3)
    fn()
    np.testing.assert_allclose(shared_input.get_value(), 10, rtol=1e-3)
102+
66103@pytest .mark .parametrize (
67104 "rv_op, dist_params, base_size, cdf_name, params_conv" ,
68105 [
0 commit comments