@@ -133,12 +133,19 @@ def compute_z(x):
133133 ],
134134)
135135@pytest .mark .parametrize (
136- "backend, gradient_backend" ,
137- [("jax" , "jax" ), ("jax" , "pytensor" )],
136+ "backend, gradient_backend, include_transformed " ,
137+ [("jax" , "jax" , True ), ("jax" , "pytensor" , False )],
138138 ids = str ,
139139)
140140def test_find_MAP (
141- method , use_grad , use_hess , use_hessp , backend , gradient_backend : GradientBackend , rng
141+ method ,
142+ use_grad ,
143+ use_hess ,
144+ use_hessp ,
145+ backend ,
146+ gradient_backend : GradientBackend ,
147+ include_transformed ,
148+ rng ,
142149):
143150 pytest .importorskip ("jax" )
144151
@@ -154,12 +161,12 @@ def test_find_MAP(
154161 use_hessp = use_hessp ,
155162 progressbar = False ,
156163 gradient_backend = gradient_backend ,
164+ include_transformed = include_transformed ,
157165 compile_kwargs = {"mode" : backend .upper ()},
158166 maxiter = 5 ,
159167 )
160168
161169 assert hasattr (idata , "posterior" )
162- assert hasattr (idata , "unconstrained_posterior" )
163170 assert hasattr (idata , "fit" )
164171 assert hasattr (idata , "optimizer_result" )
165172 assert hasattr (idata , "observed_data" )
@@ -169,9 +176,13 @@ def test_find_MAP(
169176 assert posterior ["mu" ].shape == ()
170177 assert posterior ["sigma" ].shape == ()
171178
172- unconstrained_posterior = idata .unconstrained_posterior .squeeze (["chain" , "draw" ])
173- assert "sigma_log__" in unconstrained_posterior
174- assert unconstrained_posterior ["sigma_log__" ].shape == ()
179+ if include_transformed :
180+ assert hasattr (idata , "unconstrained_posterior" )
181+ unconstrained_posterior = idata .unconstrained_posterior .squeeze (["chain" , "draw" ])
182+ assert "sigma_log__" in unconstrained_posterior
183+ assert unconstrained_posterior ["sigma_log__" ].shape == ()
184+ else :
185+ assert not hasattr (idata , "unconstrained_posterior" )
175186
176187
177188@pytest .mark .parametrize (
0 commit comments