@@ -133,7 +133,7 @@ def test_blackjax_particles_from_pymc_population_univariate():
133133 model = fast_model ()
134134 population = {"x" : np .array ([2 , 3 , 4 ])}
135135 blackjax_particles = blackjax_particles_from_pymc_population (model , population )
136- jax .tree_map (np .testing .assert_allclose , blackjax_particles , [np .array ([[2 ], [3 ], [4 ]])])
136+ jax .tree . map (np .testing .assert_allclose , blackjax_particles , [np .array ([[2 ], [3 ], [4 ]])])
137137
138138
139139def test_blackjax_particles_from_pymc_population_multivariate ():
@@ -144,7 +144,7 @@ def test_blackjax_particles_from_pymc_population_multivariate():
144144
145145 population = {"x" : np .array ([0.34614613 , 1.09163261 , - 0.44526825 ]), "z" : np .array ([1 , 2 , 3 ])}
146146 blackjax_particles = blackjax_particles_from_pymc_population (model , population )
147- jax .tree_map (
147+ jax .tree . map (
148148 np .testing .assert_allclose ,
149149 blackjax_particles ,
150150 [np .array ([[0.34614613 ], [1.09163261 ], [- 0.44526825 ]]), np .array ([[1 ], [2 ], [3 ]])],
@@ -168,7 +168,7 @@ def test_blackjax_particles_from_pymc_population_multivariable():
168168 population = {"x" : np .array ([[2 , 3 ], [5 , 6 ], [7 , 9 ]]), "z" : np .array ([11 , 12 , 13 ])}
169169 blackjax_particles = blackjax_particles_from_pymc_population (model , population )
170170
171- jax .tree_map (
171+ jax .tree . map (
172172 np .testing .assert_allclose ,
173173 blackjax_particles ,
174174 [np .array ([[2 , 3 ], [5 , 6 ], [7 , 9 ]]), np .array ([[11 ], [12 ], [13 ]])],
@@ -196,7 +196,7 @@ def test_get_jaxified_logprior():
196196 """
197197 logprior = get_jaxified_logprior (fast_model ())
198198 for point in [- 0.5 , 0.0 , 0.5 ]:
199- jax .tree_map (
199+ jax .tree . map (
200200 np .testing .assert_allclose ,
201201 jax .vmap (logprior )([np .array ([point ])]),
202202 np .log (scipy .stats .norm (0 , 1 ).pdf (point )),
@@ -212,7 +212,7 @@ def test_get_jaxified_loglikelihood():
212212 """
213213 loglikelihood = get_jaxified_loglikelihood (fast_model ())
214214 for point in [- 0.5 , 0.0 , 0.5 ]:
215- jax .tree_map (
215+ jax .tree . map (
216216 np .testing .assert_allclose ,
217217 jax .vmap (loglikelihood )([np .array ([point ])]),
218218 np .log (scipy .stats .norm (point , 1 ).pdf (0 )),
0 commit comments