2626 vectorize_graph ,
2727)
2828from pytensor .scan import map as scan_map
29- from pytensor .tensor import TensorVariable
29+ from pytensor .tensor import TensorType , TensorVariable
3030from pytensor .tensor .elemwise import Elemwise
3131from pytensor .tensor .shape import Shape
3232from pytensor .tensor .special import log_softmax
@@ -381,41 +381,36 @@ def transform_input(inputs):
381381
382382 rv_dict = {}
383383 rv_dims = {}
384- for seed , rv in zip (seeds , vars_to_recover ):
384+ for seed , marginalized_rv in zip (seeds , vars_to_recover ):
385385 supported_dists = (Bernoulli , Categorical , DiscreteUniform )
386- if not isinstance (rv .owner .op , supported_dists ):
386+ if not isinstance (marginalized_rv .owner .op , supported_dists ):
387387 raise NotImplementedError (
388- f"RV with distribution { rv .owner .op } cannot be recovered. "
388+ f"RV with distribution { marginalized_rv .owner .op } cannot be recovered. "
389389 f"Supported distribution include { supported_dists } "
390390 )
391391
392392 m = self .clone ()
393- rv = m .vars_to_clone [rv ]
394- m .unmarginalize ([rv ])
395- dependent_vars = find_conditional_dependent_rvs (rv , m .basic_RVs )
396- joint_logps = m .logp (vars = dependent_vars + [ rv ] , sum = False )
393+ marginalized_rv = m .vars_to_clone [marginalized_rv ]
394+ m .unmarginalize ([marginalized_rv ])
395+ dependent_vars = find_conditional_dependent_rvs (marginalized_rv , m .basic_RVs )
396+ joint_logps = m .logp (vars = [ marginalized_rv ] + dependent_vars , sum = False )
397397
398- marginalized_value = m .rvs_to_values [rv ]
398+ marginalized_value = m .rvs_to_values [marginalized_rv ]
399399 other_values = [v for v in m .value_vars if v is not marginalized_value ]
400400
401401 # Handle batch dims for marginalized value and its dependent RVs
402- joint_logp = joint_logps [- 1 ]
403- for dv in joint_logps [:- 1 ]:
404- dbcast = dv .type .broadcastable
405- mbcast = marginalized_value .type .broadcastable
406- mbcast = (True ,) * (len (dbcast ) - len (mbcast )) + mbcast
407- values_axis_bcast = [
408- i for i , (m , v ) in enumerate (zip (mbcast , dbcast )) if m and not v
409- ]
410- joint_logp += dv .sum (values_axis_bcast )
402+ marginalized_logp , * dependent_logps = joint_logps
403+ joint_logp = marginalized_logp + _add_reduce_batch_dependent_logps (
404+ marginalized_rv .type , dependent_logps
405+ )
411406
412- rv_shape = constant_fold (tuple (rv .shape ))
413- rv_domain = get_domain_of_finite_discrete_rv (rv )
407+ rv_shape = constant_fold (tuple (marginalized_rv .shape ))
408+ rv_domain = get_domain_of_finite_discrete_rv (marginalized_rv )
414409 rv_domain_tensor = pt .moveaxis (
415410 pt .full (
416411 (* rv_shape , len (rv_domain )),
417412 rv_domain ,
418- dtype = rv .dtype ,
413+ dtype = marginalized_rv .dtype ,
419414 ),
420415 - 1 ,
421416 0 ,
@@ -431,7 +426,7 @@ def transform_input(inputs):
431426 joint_logps_norm = log_softmax (joint_logps , axis = - 1 )
432427 if return_samples :
433428 sample_rv_outs = pymc .Categorical .dist (logit_p = joint_logps )
434- if isinstance (rv .owner .op , DiscreteUniform ):
429+ if isinstance (marginalized_rv .owner .op , DiscreteUniform ):
435430 sample_rv_outs += rv_domain [0 ]
436431
437432 rv_loglike_fn = compile_pymc (
@@ -456,18 +451,20 @@ def transform_input(inputs):
456451 logps , samples = zip (* logvs )
457452 logps = np .array (logps )
458453 samples = np .array (samples )
459- rv_dict [rv .name ] = samples .reshape (
454+ rv_dict [marginalized_rv .name ] = samples .reshape (
460455 tuple (len (coord ) for coord in stacked_dims .values ()) + samples .shape [1 :],
461456 )
462457 else :
463458 logps = np .array (logvs )
464459
465- rv_dict ["lp_" + rv .name ] = logps .reshape (
460+ rv_dict ["lp_" + marginalized_rv .name ] = logps .reshape (
466461 tuple (len (coord ) for coord in stacked_dims .values ()) + logps .shape [1 :],
467462 )
468- if rv .name in m .named_vars_to_dims :
469- rv_dims [rv .name ] = list (m .named_vars_to_dims [rv .name ])
470- rv_dims ["lp_" + rv .name ] = rv_dims [rv .name ] + ["lp_" + rv .name + "_dim" ]
463+ if marginalized_rv .name in m .named_vars_to_dims :
464+ rv_dims [marginalized_rv .name ] = list (m .named_vars_to_dims [marginalized_rv .name ])
465+ rv_dims ["lp_" + marginalized_rv .name ] = rv_dims [marginalized_rv .name ] + [
466+ "lp_" + marginalized_rv .name + "_dim"
467+ ]
471468
472469 coords , dims = coords_and_dims_for_inferencedata (self )
473470 dims .update (rv_dims )
@@ -647,6 +644,22 @@ def get_domain_of_finite_discrete_rv(rv: TensorVariable) -> Tuple[int, ...]:
647644 raise NotImplementedError (f"Cannot compute domain for op { op } " )
648645
649646
def _add_reduce_batch_dependent_logps(
    marginalized_type: TensorType, dependent_logps: Sequence[TensorVariable]
):
    """Add the logps of dependent RVs while reducing extra batch dims as assessed from the `marginalized_type`."""
    marg_bcast = marginalized_type.broadcastable

    def _reduce_extra_batch_dims(dependent_logp):
        # Left-pad the marginalized broadcast pattern with True so it lines up
        # with the (possibly higher-rank) dependent logp, then sum away every
        # axis that is broadcastable in the marginalized RV but not in the
        # dependent logp (i.e. the extra batch dims contributed by dependents).
        dep_bcast = dependent_logp.type.broadcastable
        padding = (True,) * (len(dep_bcast) - len(marg_bcast))
        aligned = padding + marg_bcast
        extra_axes = [
            axis
            for axis, (marg_b, dep_b) in enumerate(zip(aligned, dep_bcast))
            if marg_b and not dep_b
        ]
        return dependent_logp.sum(extra_axes)

    return pt.add(*(_reduce_extra_batch_dims(logp) for logp in dependent_logps))
662+
650663@_logprob .register (FiniteDiscreteMarginalRV )
651664def finite_discrete_marginal_rv_logp (op , values , * inputs , ** kwargs ):
652665 # Clone the inner RV graph of the Marginalized RV
@@ -662,17 +675,12 @@ def finite_discrete_marginal_rv_logp(op, values, *inputs, **kwargs):
662675 logps_dict = conditional_logp (rv_values = inner_rvs_to_values , ** kwargs )
663676
664677 # Reduce logp dimensions corresponding to broadcasted variables
665- joint_logp = logps_dict [inner_rvs_to_values [marginalized_rv ]]
666- for inner_rv , inner_value in inner_rvs_to_values .items ():
667- if inner_rv is marginalized_rv :
668- continue
669- vbcast = inner_value .type .broadcastable
670- mbcast = marginalized_rv .type .broadcastable
671- mbcast = (True ,) * (len (vbcast ) - len (mbcast )) + mbcast
672- values_axis_bcast = [i for i , (m , v ) in enumerate (zip (mbcast , vbcast )) if m != v ]
673- joint_logp += logps_dict [inner_value ].sum (values_axis_bcast , keepdims = True )
674-
675- # Wrap the joint_logp graph in an OpFromGrah, so that we can evaluate it at different
678+ marginalized_logp = logps_dict .pop (inner_rvs_to_values [marginalized_rv ])
679+ joint_logp = marginalized_logp + _add_reduce_batch_dependent_logps (
680+ marginalized_rv .type , logps_dict .values ()
681+ )
682+
683+ # Wrap the joint_logp graph in an OpFromGraph, so that we can evaluate it at different
676684 # values of the marginalized RV
677685 # Some inputs are not root inputs (such as transformed projections of value variables)
678686 # Or cannot be used as inputs to an OpFromGraph (shared variables and constants)
@@ -700,6 +708,7 @@ def finite_discrete_marginal_rv_logp(op, values, *inputs, **kwargs):
700708 )
701709
702710 # Arbitrary cutoff to switch to Scan implementation to keep graph size under control
711+ # TODO: Try vectorize here
703712 if len (marginalized_rv_domain ) <= 10 :
704713 joint_logps = [
705714 joint_logp_op (marginalized_rv_domain_tensor [i ], * values , * inputs )
0 commit comments