 from pymc.distributions.discrete import Bernoulli, Categorical, DiscreteUniform
 from pymc.distributions.transforms import Chain
 from pymc.logprob.abstract import _logprob
-from pymc.logprob.basic import conditional_logp
+from pymc.logprob.basic import conditional_logp, logp
 from pymc.logprob.transforms import IntervalTransform
 from pymc.model import Model
 from pymc.pytensorf import compile_pymc, constant_fold, inputvars
 from pymc.util import _get_seeds_per_chain, dataset_to_point_list, treedict
-from pytensor import Mode
+from pytensor import Mode, scan
 from pytensor.compile import SharedVariable
 from pytensor.compile.builders import OpFromGraph
-from pytensor.graph import (
-    Constant,
-    FunctionGraph,
-    ancestors,
-    clone_replace,
-    vectorize_graph,
-)
+from pytensor.graph import Constant, FunctionGraph, ancestors, clone_replace
+from pytensor.graph.replace import vectorize_graph
 from pytensor.scan import map as scan_map
 from pytensor.tensor import TensorType, TensorVariable
 from pytensor.tensor.elemwise import Elemwise
 
 __all__ = ["MarginalModel"]
 
+from pymc_experimental.distributions import DiscreteMarkovChain
+
 
 class MarginalModel(Model):
     """Subclass of PyMC Model that implements functionality for automatic
@@ -247,16 +244,25 @@ def marginalize(
             self[var] if isinstance(var, str) else var for var in rvs_to_marginalize
         ]
 
-        supported_dists = (Bernoulli, Categorical, DiscreteUniform)
         for rv_to_marginalize in rvs_to_marginalize:
             if rv_to_marginalize not in self.free_RVs:
                 raise ValueError(
                     f"Marginalized RV {rv_to_marginalize} is not a free RV in the model"
                 )
-            if not isinstance(rv_to_marginalize.owner.op, supported_dists):
+
+            rv_op = rv_to_marginalize.owner.op
+            if isinstance(rv_op, DiscreteMarkovChain):
+                if rv_op.n_lags > 1:
+                    raise NotImplementedError(
+                        "Marginalization for DiscreteMarkovChain with n_lags > 1 is not supported"
+                    )
+                if rv_to_marginalize.owner.inputs[0].type.ndim > 2:
+                    raise NotImplementedError(
+                        "Marginalization for DiscreteMarkovChain with non-matrix transition probability is not supported"
+                    )
+            elif not isinstance(rv_op, (Bernoulli, Categorical, DiscreteUniform)):
                 raise NotImplementedError(
-                    f"RV with distribution {rv_to_marginalize.owner.op} cannot be marginalized. "
-                    f"Supported distribution include {supported_dists}"
+                    f"Marginalization of RV with distribution {rv_to_marginalize.owner.op} is not supported"
                 )
 
             if rv_to_marginalize.name in self.named_vars_to_dims:
@@ -492,6 +498,10 @@ class FiniteDiscreteMarginalRV(MarginalRV):
     """Base class for Finite Discrete Marginalized RVs"""
 
 
+class DiscreteMarginalMarkovChainRV(MarginalRV):
+    """Base class for Discrete Marginal Markov Chain RVs"""
+
+
 def static_shape_ancestors(vars):
     """Identify ancestors Shape Ops of static shapes (therefore constant in a valid graph)."""
     return [
@@ -620,11 +630,17 @@ def replace_finite_discrete_marginal_subgraph(fgraph, rv_to_marginalize, all_rvs
     replace_inputs.update({input_rv: input_rv.type() for input_rv in input_rvs})
     cloned_outputs = clone_replace(outputs, replace=replace_inputs)
 
-    marginalization_op = FiniteDiscreteMarginalRV(
+    if isinstance(rv_to_marginalize.owner.op, DiscreteMarkovChain):
+        marginalize_constructor = DiscreteMarginalMarkovChainRV
+    else:
+        marginalize_constructor = FiniteDiscreteMarginalRV
+
+    marginalization_op = marginalize_constructor(
         inputs=list(replace_inputs.values()),
         outputs=cloned_outputs,
         ndim_supp=ndim_supp,
     )
+
     marginalized_rvs = marginalization_op(*replace_inputs.keys())
     fgraph.replace_all(tuple(zip(rvs_to_marginalize, marginalized_rvs)))
     return rvs_to_marginalize, marginalized_rvs
@@ -640,14 +656,17 @@ def get_domain_of_finite_discrete_rv(rv: TensorVariable) -> Tuple[int, ...]:
     elif isinstance(op, DiscreteUniform):
         lower, upper = constant_fold(rv.owner.inputs[3:])
         return tuple(range(lower, upper + 1))
+    elif isinstance(op, DiscreteMarkovChain):
+        P = rv.owner.inputs[0]
+        return tuple(range(pt.get_vector_length(P[-1])))
 
     raise NotImplementedError(f"Cannot compute domain for op {op}")
 
 
 def _add_reduce_batch_dependent_logps(
     marginalized_type: TensorType, dependent_logps: Sequence[TensorVariable]
 ):
-    """Add the logps of dependent RVs while reducing extra batch dims as assessed from the `marginalized_type`."""
+    """Add the logps of dependent RVs while reducing extra batch dims relative to `marginalized_type`."""
 
     mbcast = marginalized_type.broadcastable
     reduced_logps = []
@@ -730,3 +749,69 @@ def logp_fn(marginalized_rv_const, *non_sequences):
 
     # We have to add dummy logps for the remaining value variables, otherwise PyMC will raise
     return joint_logps, *(pt.constant(0),) * (len(values) - 1)
+
+
+@_logprob.register(DiscreteMarginalMarkovChainRV)
+def marginal_hmm_logp(op, values, *inputs, **kwargs):
+
+    marginalized_rvs_node = op.make_node(*inputs)
+    inner_rvs = clone_replace(
+        op.inner_outputs,
+        replace={u: v for u, v in zip(op.inner_inputs, marginalized_rvs_node.inputs)},
+    )
+
+    chain_rv, *dependent_rvs = inner_rvs
+    P, n_steps_, init_dist_, rng = chain_rv.owner.inputs
+    domain = pt.arange(P.shape[-1], dtype="int32")
+
+    # Construct logp in two steps
+    # Step 1: Compute the probability of the data ("emissions") under every possible state (vec_logp_emission)
+
+    # First we need to vectorize the conditional logp graph of the data, in case there are batch dimensions floating
+    # around. To do this, we need to break the dependency between chain and the init_dist_ random variable. Otherwise,
+    # PyMC will detect a random variable in the logp graph (init_dist_), that isn't relevant at this step.
+    chain_value = chain_rv.clone()
+    dependent_rvs = clone_replace(dependent_rvs, {chain_rv: chain_value})
+    logp_emissions_dict = conditional_logp(dict(zip(dependent_rvs, values)))
+
+    # Reduce and add the batch dims beyond the chain dimension
+    reduced_logp_emissions = _add_reduce_batch_dependent_logps(
+        chain_rv.type, logp_emissions_dict.values()
+    )
+
+    # Add a batch dimension for the domain of the chain
+    chain_shape = constant_fold(tuple(chain_rv.shape))
+    batch_chain_value = pt.moveaxis(pt.full((*chain_shape, domain.size), domain), -1, 0)
+    batch_logp_emissions = vectorize_graph(reduced_logp_emissions, {chain_value: batch_chain_value})
+
+    # Step 2: Compute the transition probabilities
+    # This is the "forward algorithm", alpha_t = p(y | s_t) * sum_{s_{t-1}}(p(s_t | s_{t-1}) * alpha_{t-1})
+    # We do it entirely in logs, though.
+
+    # To compute the prior probabilities of each state, we evaluate the logp of the domain (all possible states) under
+    # the initial distribution. This is robust to everything the user can throw at it.
+    batch_logp_init_dist = pt.vectorize(lambda x: logp(init_dist_, x), "()->()")(
+        batch_chain_value[..., 0]
+    )
+    log_alpha_init = batch_logp_init_dist + batch_logp_emissions[..., 0]
+
+    def step_alpha(logp_emission, log_alpha, log_P):
+        step_log_prob = pt.logsumexp(log_alpha[:, None] + log_P, axis=0)
+        return logp_emission + step_log_prob
+
+    P_bcast_dims = (len(chain_shape) - 1) - (P.type.ndim - 2)
+    log_P = pt.shape_padright(pt.log(P), P_bcast_dims)
+    log_alpha_seq, _ = scan(
+        step_alpha,
+        non_sequences=[log_P],
+        outputs_info=[log_alpha_init],
+        # Scan needs the time dimension first, and we already consumed the 1st logp computing the initial value
+        sequences=pt.moveaxis(batch_logp_emissions[..., 1:], -1, 0),
+    )
+    # Final logp is just the sum of the last scan state
+    joint_logp = pt.logsumexp(log_alpha_seq[-1], axis=0)
+
+    # If there are multiple emission streams, we have to add dummy logps for the remaining value variables. The first
+    # return is the joint probability of everything together, but PyMC still expects one logp for each one.
+    dummy_logps = (pt.constant(0),) * (len(values) - 1)
+    return joint_logp, *dummy_logps
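
Note on the logp added above: the scan is the forward algorithm written in log space, propagating log alpha_t along the time dimension. Below is a minimal NumPy sketch of the same recursion for a single emission stream; the function and array names are illustrative and not part of this commit. log_init holds the initial-state log-probabilities, log_P[i, j] is log p(s_t = j | s_{t-1} = i) (the same orientation assumed by step_alpha), and log_emissions[i, t] is log p(y_t | s_t = i).

import numpy as np
from scipy.special import logsumexp


def hmm_forward_logp(log_init, log_P, log_emissions):
    # alpha_0(i) = p(s_0 = i) * p(y_0 | s_0 = i), in logs
    log_alpha = log_init + log_emissions[:, 0]
    for t in range(1, log_emissions.shape[1]):
        # alpha_t(j) = p(y_t | s_t = j) * sum_i alpha_{t-1}(i) * P[i, j], in logs
        log_alpha = log_emissions[:, t] + logsumexp(log_alpha[:, None] + log_P, axis=0)
    # Marginal likelihood of the emissions: sum over the final state, in logs
    return logsumexp(log_alpha)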
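
As a usage illustration of the marginalization path this commit enables: a minimal sketch of marginalizing a two-state DiscreteMarkovChain with Gaussian emissions. The model structure, parameter values, and data are illustrative assumptions, not taken from this commit or its tests.

import numpy as np
import pymc as pm

from pymc_experimental import MarginalModel
from pymc_experimental.distributions import DiscreteMarkovChain

data = np.array([-1.2, -0.9, 1.1, 0.8, 1.3, -1.1, -0.8, 1.0, 0.9, 1.2])

with MarginalModel() as m:
    # Fixed two-state transition matrix and initial state distribution
    P = np.array([[0.9, 0.1], [0.2, 0.8]])
    init_dist = pm.Categorical.dist(p=[0.5, 0.5])
    chain = DiscreteMarkovChain("chain", P=P, init_dist=init_dist, shape=(10,))
    # Gaussian emissions centered at -1 or +1 depending on the hidden state
    sigma = pm.HalfNormal("sigma")
    emission = pm.Normal("emission", mu=chain * 2 - 1, sigma=sigma, observed=data)

# Sum the hidden states out of the logp; sampling then only sees the continuous variables
m.marginalize([chain])
with m:
    idata = pm.sample()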