@@ -83,6 +83,28 @@ def test_fit_laplace_basic(mode, gradient_backend: GradientBackend):
8383 np .testing .assert_allclose (idata .fit ["covariance_matrix" ].values , bda_cov , rtol = 1e-3 , atol = 1e-3 )
8484
8585
86+ def test_fit_laplace_outside_model_context ():
87+ with pm .Model () as m :
88+ mu = pm .Normal ("mu" , 0 , 1 )
89+ sigma = pm .Exponential ("sigma" , 1 )
90+ y_hat = pm .Normal ("y_hat" , mu = mu , sigma = sigma , observed = np .random .normal (size = 10 ))
91+
92+ idata = fit_laplace (
93+ model = m ,
94+ optimize_method = "L-BFGS-B" ,
95+ use_grad = True ,
96+ progressbar = False ,
97+ chains = 1 ,
98+ draws = 100 ,
99+ )
100+
101+ assert hasattr (idata , "posterior" )
102+ assert hasattr (idata , "fit" )
103+ assert hasattr (idata , "optimizer_result" )
104+
105+ assert idata .posterior ["mu" ].shape == (1 , 100 )
106+
107+
86108@pytest .mark .parametrize (
87109 "include_transformed" , [True , False ], ids = ["include_transformed" , "no_transformed" ]
88110)
@@ -208,6 +230,50 @@ def test_model_with_nonstandard_dimensionality(rng):
208230 assert "class" in list (idata .unconstrained_posterior .sigma_log__ .coords .keys ())
209231
210232
233+ def test_laplace_nonstandard_dims_2d ():
234+ true_P = np .array ([[0.5 , 0.3 , 0.2 ], [0.1 , 0.6 , 0.3 ], [0.2 , 0.4 , 0.4 ]])
235+ y_obs = pm .draw (
236+ pmx .DiscreteMarkovChain .dist (
237+ P = true_P ,
238+ init_dist = pm .Categorical .dist (
239+ logit_p = np .ones (
240+ 3 ,
241+ )
242+ ),
243+ shape = (100 , 5 ),
244+ )
245+ )
246+
247+ with pm .Model (
248+ coords = {
249+ "time" : range (y_obs .shape [0 ]),
250+ "state" : list ("ABC" ),
251+ "next_state" : list ("ABC" ),
252+ "unit" : [1 , 2 , 3 , 4 , 5 ],
253+ }
254+ ) as model :
255+ y = pm .Data ("y" , y_obs , dims = ["time" , "unit" ])
256+ init_dist = pm .Categorical .dist (
257+ logit_p = np .ones (
258+ 3 ,
259+ )
260+ )
261+ P = pm .Dirichlet ("P" , a = np .eye (3 ) * 2 + 1 , dims = ["state" , "next_state" ])
262+ y_hat = pmx .DiscreteMarkovChain (
263+ "y_hat" , P = P , init_dist = init_dist , dims = ["time" , "unit" ], observed = y_obs
264+ )
265+
266+ idata = pmx .fit_laplace (progressbar = True )
267+
268+ # The simplex transform should drop from the right-most dimension, so the left dimension should be unmodified
269+ assert "state" in list (idata .unconstrained_posterior .P_simplex__ .coords .keys ())
270+
271+ # The mutated dimension should be unknown coords
272+ assert "P_simplex___dim_1" in list (idata .unconstrained_posterior .P_simplex__ .coords .keys ())
273+
274+ assert idata .unconstrained_posterior .P_simplex__ .shape [- 2 :] == (3 , 2 )
275+
276+
211277def test_laplace_nonscalar_rv_without_dims ():
212278 with pm .Model (coords = {"test" : ["A" , "B" , "C" ]}) as model :
213279 x_loc = pm .Normal ("x_loc" , mu = 0 , sigma = 1 , dims = ["test" ])
0 commit comments