From 52d8a92eb9a0af436bee77ea46a41e3640cc2189 Mon Sep 17 00:00:00 2001 From: Jonathan Dekermanjian Date: Fri, 7 Nov 2025 07:20:44 -0700 Subject: [PATCH 1/3] updated dim_shape assignment logic in fit_laplace to handle absent dims on data containers and deterministics --- .../inference/laplace_approx/laplace.py | 4 +- .../inference/laplace_approx/test_laplace.py | 38 +++++++++++++++++++ 2 files changed, 40 insertions(+), 2 deletions(-) diff --git a/pymc_extras/inference/laplace_approx/laplace.py b/pymc_extras/inference/laplace_approx/laplace.py index cd248dcf..72247555 100644 --- a/pymc_extras/inference/laplace_approx/laplace.py +++ b/pymc_extras/inference/laplace_approx/laplace.py @@ -226,9 +226,9 @@ def model_to_laplace_approx( else: dims = (*batch_dims, *[f"{name}_dim_{i}" for i in range(batched_rv.ndim - 2)]) initval = initial_point.get(name, None) - dim_shapes = initval.shape if initval is not None else batched_rv.type.shape[2:] + dim_shapes = initval.shape if initval is not None else batched_rv.shape.eval()[2:] laplace_model.add_coords( - {name: np.arange(shape) for name, shape in zip(dims[2:], dim_shapes)} + {name: pt.arange(shape) for name, shape in zip(dims[2:], dim_shapes)} ) pm.Deterministic(name, batched_rv, dims=dims) diff --git a/tests/inference/laplace_approx/test_laplace.py b/tests/inference/laplace_approx/test_laplace.py index f02f296a..d9edb958 100644 --- a/tests/inference/laplace_approx/test_laplace.py +++ b/tests/inference/laplace_approx/test_laplace.py @@ -193,6 +193,44 @@ def test_fit_laplace_ragged_coords(rng): assert (idata["posterior"].beta.sel(feature=1).to_numpy() > 0).all() +def test_fit_laplace_no_data_or_deterministic_dims(rng): + coords = {"city": ["A", "B", "C"], "feature": [0, 1], "obs_idx": np.arange(100)} + with pm.Model(coords=coords) as ragged_dim_model: + X = pm.Data("X", np.ones((100, 2))) + beta = pm.Normal( + "beta", mu=[[-100.0, 100.0], [-100.0, 100.0], [-100.0, 100.0]], dims=["city", "feature"] + ) + mu = pm.Deterministic("mu", (X[:, None, :] * beta[None]).sum(axis=-1)) + sigma = pm.Normal("sigma", mu=1.5, sigma=0.5, dims=["city"]) + + obs = pm.Normal( + "obs", + mu=mu, + sigma=sigma, + observed=rng.normal(loc=3, scale=1.5, size=(100, 3)), + dims=["obs_idx", "city"], + ) + + idata = fit_laplace( + optimize_method="Newton-CG", + progressbar=False, + use_grad=True, + use_hessp=True, + ) + + # These should have been dropped when the laplace idata was created + assert "laplace_approximation" not in list(idata.posterior.data_vars.keys()) + assert "unpacked_var_names" not in list(idata.posterior.coords.keys()) + + assert idata["posterior"].beta.shape[-2:] == (3, 2) + assert idata["posterior"].sigma.shape[-1:] == (3,) + + # Check that everything got unraveled correctly -- feature 0 should be strictly negative, feature 1 + # strictly positive + assert (idata["posterior"].beta.sel(feature=0).to_numpy() < 0).all() + assert (idata["posterior"].beta.sel(feature=1).to_numpy() > 0).all() + + def test_model_with_nonstandard_dimensionality(rng): y_obs = np.concatenate( [rng.normal(-1, 2, size=150), rng.normal(3, 1, size=350), rng.normal(5, 4, size=50)] From 1e417582deb7325cb21c8537ef9208ad60a88722 Mon Sep 17 00:00:00 2001 From: Jonathan Dekermanjian Date: Sat, 8 Nov 2025 05:36:13 -0700 Subject: [PATCH 2/3] reverted to numpy arange --- pymc_extras/inference/laplace_approx/laplace.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pymc_extras/inference/laplace_approx/laplace.py b/pymc_extras/inference/laplace_approx/laplace.py index 72247555..7ec838bc 100644 --- a/pymc_extras/inference/laplace_approx/laplace.py +++ b/pymc_extras/inference/laplace_approx/laplace.py @@ -228,7 +228,7 @@ def model_to_laplace_approx( initval = initial_point.get(name, None) dim_shapes = initval.shape if initval is not None else batched_rv.shape.eval()[2:] laplace_model.add_coords( - {name: pt.arange(shape) for name, shape in zip(dims[2:], dim_shapes)} + {name: np.arange(shape) for name, shape in zip(dims[2:], dim_shapes)} ) pm.Deterministic(name, batched_rv, dims=dims) From 5abe4ff11221531d82e7f9db879bb18acf9e3249 Mon Sep 17 00:00:00 2001 From: Jonathan Dekermanjian Date: Sat, 8 Nov 2025 08:40:32 -0700 Subject: [PATCH 3/3] updated dim handling in model_to_laplace_approx to not force dims on variables that did not have them originally --- pymc_extras/inference/laplace_approx/laplace.py | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/pymc_extras/inference/laplace_approx/laplace.py b/pymc_extras/inference/laplace_approx/laplace.py index 7ec838bc..d78fc3df 100644 --- a/pymc_extras/inference/laplace_approx/laplace.py +++ b/pymc_extras/inference/laplace_approx/laplace.py @@ -224,12 +224,15 @@ def model_to_laplace_approx( elif name in model.named_vars_to_dims: dims = (*batch_dims, *model.named_vars_to_dims[name]) else: - dims = (*batch_dims, *[f"{name}_dim_{i}" for i in range(batched_rv.ndim - 2)]) initval = initial_point.get(name, None) - dim_shapes = initval.shape if initval is not None else batched_rv.shape.eval()[2:] - laplace_model.add_coords( - {name: np.arange(shape) for name, shape in zip(dims[2:], dim_shapes)} - ) + dim_shapes = initval.shape if initval is not None else batched_rv.type.shape[2:] + if dim_shapes[0] is not None: + dims = (*batch_dims, *[f"{name}_dim_{i}" for i in range(batched_rv.ndim - 2)]) + laplace_model.add_coords( + {name: np.arange(shape) for name, shape in zip(dims[2:], dim_shapes)} + ) + else: + dims = None pm.Deterministic(name, batched_rv, dims=dims)