@@ -212,8 +212,10 @@ def logp(self, vars=None, **kwargs):
212212 return m ._logp (vars = vars , ** kwargs )
213213
214214 def clone (self ):
215- m = MarginalModel ()
216- vars = self .basic_RVs + self .potentials + self .deterministics + self .marginalized_rvs
215+ m = MarginalModel (coords = self .coords )
216+ model_vars = self .basic_RVs + self .potentials + self .deterministics + self .marginalized_rvs
217+ data_vars = [var for name , var in self .named_vars .items () if var not in model_vars ]
218+ vars = model_vars + data_vars
217219 cloned_vars = clone_replace (vars )
218220 vars_to_clone = {var : cloned_var for var , cloned_var in zip (vars , cloned_vars )}
219221 m .vars_to_clone = vars_to_clone
@@ -598,7 +600,7 @@ def replace_finite_discrete_marginal_subgraph(fgraph, rv_to_marginalize, all_rvs
598600 # can ultimately be generated that is proportional to the support domain and not
599601 # to the variables dimensions
600602 # We don't need to worry about this if the RV is scalar.
601- if np .prod (constant_fold (tuple (rv_to_marginalize .shape ))) > 1 :
603+ if np .prod (constant_fold (tuple (rv_to_marginalize .shape ), raise_not_constant = False )) != 1 :
602604 if not is_elemwise_subgraph (rv_to_marginalize , dependent_rvs_input_rvs , dependent_rvs ):
603605 raise NotImplementedError (
604606 "The subgraph between a marginalized RV and its dependents includes non Elemwise operations. "
@@ -682,7 +684,7 @@ def finite_discrete_marginal_rv_logp(op, values, *inputs, **kwargs):
682684 # batched dimensions of the marginalized RV
683685
684686 # PyMC does not allow RVs in the logp graph, even if we are just using the shape
685- marginalized_rv_shape = constant_fold (tuple (marginalized_rv .shape ))
687+ marginalized_rv_shape = constant_fold (tuple (marginalized_rv .shape ), raise_not_constant = False )
686688 marginalized_rv_domain = get_domain_of_finite_discrete_rv (marginalized_rv )
687689 marginalized_rv_domain_tensor = pt .moveaxis (
688690 pt .full (
0 commit comments