99from arviz import InferenceData , dict_to_dataset
1010from pymc .backends .arviz import coords_and_dims_for_inferencedata , dataset_to_point_list
1111from pymc .distributions .discrete import Bernoulli , Categorical , DiscreteUniform
12+ from pymc .distributions .multivariate import MvNormal
1213from pymc .distributions .transforms import Chain
1314from pymc .logprob .transforms import IntervalTransform
1415from pymc .model import Model
4546from pymc_extras .model .marginal .distributions import (
4647 MarginalDiscreteMarkovChainRV ,
4748 MarginalFiniteDiscreteRV ,
49+ MarginalLaplaceRV ,
4850 MarginalRV ,
4951 NonSeparableLogpWarning ,
5052 get_domain_of_finite_discrete_rv ,
@@ -144,7 +146,9 @@ def _unique(seq: Sequence) -> list:
144146 return [x for x in seq if not (x in seen or seen_add (x ))]
145147
146148
147- def marginalize (model : Model , rvs_to_marginalize : ModelRVs ) -> MarginalModel :
149+ def marginalize (
150+ model : Model , rvs_to_marginalize : ModelRVs , use_laplace : bool = False , ** marginalize_kwargs
151+ ) -> MarginalModel :
148152 """Marginalize a subset of variables in a PyMC model.
149153
150154 This creates a class of `MarginalModel` from an existing `Model`, with the specified
@@ -158,6 +162,8 @@ def marginalize(model: Model, rvs_to_marginalize: ModelRVs) -> MarginalModel:
158162 PyMC model to marginalize. Original variables well be cloned.
159163 rvs_to_marginalize : Sequence[TensorVariable]
160164 Variables to marginalize in the returned model.
165+ use_laplace : bool
166+ Whether to use Laplace appoximations to marginalize out rvs_to_marginalize.
161167
162168 Returns
163169 -------
@@ -186,7 +192,12 @@ def marginalize(model: Model, rvs_to_marginalize: ModelRVs) -> MarginalModel:
186192 raise NotImplementedError (
187193 "Marginalization for DiscreteMarkovChain with non-matrix transition probability is not supported"
188194 )
189- elif not isinstance (rv_op , Bernoulli | Categorical | DiscreteUniform ):
195+ elif use_laplace and not isinstance (rv_op , MvNormal ):
196+ raise ValueError (
197+ f"Marginalisation method set to Laplace but RV { rv_to_marginalize } is not instance of MvNormal. Has distribution { rv_to_marginalize .owner .op } "
198+ )
199+
200+ elif not use_laplace and not isinstance (rv_op , Bernoulli | Categorical | DiscreteUniform ):
190201 raise NotImplementedError (
191202 f"Marginalization of RV with distribution { rv_to_marginalize .owner .op } is not supported"
192203 )
@@ -241,7 +252,9 @@ def marginalize(model: Model, rvs_to_marginalize: ModelRVs) -> MarginalModel:
241252 ]
242253 input_rvs = _unique ((* marginalized_rv_input_rvs , * other_direct_rv_ancestors ))
243254
244- replace_finite_discrete_marginal_subgraph (fg , rv_to_marginalize , dependent_rvs , input_rvs )
255+ replace_marginal_subgraph (
256+ fg , rv_to_marginalize , dependent_rvs , input_rvs , use_laplace , ** marginalize_kwargs
257+ )
245258
246259 return model_from_fgraph (fg , mutate_fgraph = True )
247260
@@ -551,22 +564,32 @@ def remove_model_vars(vars):
551564 return fgraph .outputs
552565
553566
554- def replace_finite_discrete_marginal_subgraph (
555- fgraph , rv_to_marginalize , dependent_rvs , input_rvs
567+ def replace_marginal_subgraph (
568+ fgraph ,
569+ rv_to_marginalize ,
570+ dependent_rvs ,
571+ input_rvs ,
572+ use_laplace = False ,
573+ ** marginalize_kwargs ,
556574) -> None :
557575 # If the marginalized RV has multiple dimensions, check that graph between
558576 # marginalized RV and dependent RVs does not mix information from batch dimensions
559577 # (otherwise logp would require enumerating over all combinations of batch dimension values)
560- try :
561- dependent_rvs_dim_connections = subgraph_batch_dim_connection (
562- rv_to_marginalize , dependent_rvs
563- )
564- except (ValueError , NotImplementedError ) as e :
565- # For the perspective of the user this is a NotImplementedError
566- raise NotImplementedError (
567- "The graph between the marginalized and dependent RVs cannot be marginalized efficiently. "
568- "You can try splitting the marginalized RV into separate components and marginalizing them separately."
569- ) from e
578+ if not use_laplace :
579+ try :
580+ dependent_rvs_dim_connections = subgraph_batch_dim_connection (
581+ rv_to_marginalize , dependent_rvs
582+ )
583+ except (ValueError , NotImplementedError ) as e :
584+ # For the perspective of the user this is a NotImplementedError
585+ raise NotImplementedError (
586+ "The graph between the marginalized and dependent RVs cannot be marginalized efficiently. "
587+ "You can try splitting the marginalized RV into separate components and marginalizing them separately."
588+ ) from e
589+ else :
590+ dependent_rvs_dim_connections = [
591+ (None ,),
592+ ]
570593
571594 output_rvs = [rv_to_marginalize , * dependent_rvs ]
572595 rng_updates = collect_default_updates (output_rvs , inputs = input_rvs , must_be_shared = False )
@@ -581,6 +604,8 @@ def replace_finite_discrete_marginal_subgraph(
581604
582605 if isinstance (inner_outputs [0 ].owner .op , DiscreteMarkovChain ):
583606 marginalize_constructor = MarginalDiscreteMarkovChainRV
607+ elif use_laplace :
608+ marginalize_constructor = MarginalLaplaceRV
584609 else :
585610 marginalize_constructor = MarginalFiniteDiscreteRV
586611
@@ -590,6 +615,7 @@ def replace_finite_discrete_marginal_subgraph(
590615 outputs = inner_outputs ,
591616 dims_connections = dependent_rvs_dim_connections ,
592617 dims = dims ,
618+ ** marginalize_kwargs ,
593619 )
594620
595621 new_outputs = marginalization_op (* inputs )
0 commit comments