diff --git a/pymc/sampling/mcmc.py b/pymc/sampling/mcmc.py index de341c68cd..3e2dea607f 100644 --- a/pymc/sampling/mcmc.py +++ b/pymc/sampling/mcmc.py @@ -1648,7 +1648,7 @@ def model_logp_fn(ip: PointType) -> np.ndarray: compile_kwargs=compile_kwargs, ) approx_sample = approx.sample( - draws=chains, random_seed=random_seed_list[0], return_inferencedata=False + draws=chains, model=model, random_seed=random_seed_list[0], return_inferencedata=False ) initial_points = [approx_sample[i] for i in range(chains)] std_apoint = approx.std.eval() @@ -1672,7 +1672,7 @@ def model_logp_fn(ip: PointType) -> np.ndarray: compile_kwargs=compile_kwargs, ) approx_sample = approx.sample( - draws=chains, random_seed=random_seed_list[0], return_inferencedata=False + draws=chains, model=model, random_seed=random_seed_list[0], return_inferencedata=False ) initial_points = [approx_sample[i] for i in range(chains)] cov = approx.std.eval() ** 2 @@ -1690,7 +1690,7 @@ def model_logp_fn(ip: PointType) -> np.ndarray: compile_kwargs=compile_kwargs, ) approx_sample = approx.sample( - draws=chains, random_seed=random_seed_list[0], return_inferencedata=False + draws=chains, model=model, random_seed=random_seed_list[0], return_inferencedata=False ) initial_points = [approx_sample[i] for i in range(chains)] cov = approx.std.eval() ** 2 diff --git a/pymc/variational/approximations.py b/pymc/variational/approximations.py index 29b7093108..2943db0544 100644 --- a/pymc/variational/approximations.py +++ b/pymc/variational/approximations.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from functools import cached_property + import numpy as np import pytensor @@ -32,7 +34,6 @@ Group, NotImplementedInference, _known_scan_ignored_inputs, - node_property, ) __all__ = ["Empirical", "FullRank", "MeanField", "sample_approx"] @@ -52,20 +53,20 @@ class MeanFieldGroup(Group): short_name = "mean_field" alias_names = frozenset(["mf"]) - @node_property + @cached_property def mean(self): return self.params_dict["mu"] - @node_property + @cached_property def rho(self): return self.params_dict["rho"] - @node_property + @cached_property def cov(self): var = rho2sigma(self.rho) ** 2 return pt.diag(var) - @node_property + @cached_property def std(self): return rho2sigma(self.rho) @@ -85,6 +86,13 @@ def create_shared_params(self, start=None, start_sigma=None): # by `self.ordering`. In the cases I looked into these turn out to be the same, but there may be edge cases or # future code changes that break this assumption. start = self._prepare_start(start) + # Ensure start is a 1D array and matches ddim + start = np.asarray(start).flatten() + if start.size != self.ddim: + raise ValueError( + f"Start array size mismatch: got {start.size}, expected {self.ddim}. 
" + f"Start shape: {start.shape if hasattr(start, 'shape') else 'unknown'}" + ) rho1 = np.zeros((self.ddim,)) if start_sigma is not None: @@ -99,14 +107,14 @@ def create_shared_params(self, start=None, start_sigma=None): "rho": pytensor.shared(pm.floatX(rho), "rho"), } - @node_property + @cached_property def symbolic_random(self): initial = self.symbolic_initial sigma = self.std mu = self.mean return sigma * initial + mu - @node_property + @cached_property def symbolic_logq_not_scaled(self): z0 = self.symbolic_initial std = rho2sigma(self.rho) @@ -139,11 +147,18 @@ def __init_group__(self, group): def create_shared_params(self, start=None): start = self._prepare_start(start) + # Ensure start is a 1D array and matches ddim + start = np.asarray(start).flatten() + if start.size != self.ddim: + raise ValueError( + f"Start array size mismatch: got {start.size}, expected {self.ddim}. " + f"Start shape: {start.shape if hasattr(start, 'shape') else 'unknown'}" + ) n = self.ddim L_tril = np.eye(n)[np.tril_indices(n)].astype(pytensor.config.floatX) return {"mu": pytensor.shared(start, "mu"), "L_tril": pytensor.shared(L_tril, "L_tril")} - @node_property + @cached_property def L(self): L = pt.zeros((self.ddim, self.ddim)) L = pt.set_subtensor(L[self.tril_indices], self.params_dict["L_tril"]) @@ -151,16 +166,16 @@ def L(self): L = pt.set_subtensor(Ld, rho2sigma(Ld)) return L - @node_property + @cached_property def mean(self): return self.params_dict["mu"] - @node_property + @cached_property def cov(self): L = self.L return L.dot(L.T) - @node_property + @cached_property def std(self): return pt.sqrt(pt.diag(self.cov)) @@ -173,7 +188,7 @@ def num_tril_entries(self): def tril_indices(self): return np.tril_indices(self.ddim) - @node_property + @cached_property def symbolic_logq_not_scaled(self): z0 = self.symbolic_initial diag = pt.diagonal(self.L, 0, self.L.ndim - 2, self.L.ndim - 1) @@ -182,7 +197,7 @@ def symbolic_logq_not_scaled(self): logq = quaddist - logdet return logq.sum(range(1, logq.ndim)) - @node_property + @cached_property def symbolic_random(self): initial = self.symbolic_initial L = self.L @@ -233,6 +248,8 @@ def create_shared_params(self, trace=None, size=None, jitter=1, start=None): return {"histogram": pytensor.shared(pm.floatX(histogram), "histogram")} def _check_trace(self): + from pymc.model import modelcontext + trace = self._kwargs.get("trace", None) if isinstance(trace, InferenceData): raise NotImplementedError( @@ -240,10 +257,10 @@ def _check_trace(self): " Pass `pm.sample(return_inferencedata=False)` to get a `MultiTrace` to use with `Empirical`." 
" Please help us to refactor: https://github.com/pymc-devs/pymc/issues/5884" ) - elif trace is not None and not all( - self.model.rvs_to_values[var].name in trace.varnames for var in self.group - ): - raise ValueError("trace has not all free RVs in the group") + elif trace is not None: + model = modelcontext(None) + if not all(model.rvs_to_values[var].name in trace.varnames for var in self.group): + raise ValueError("trace has not all free RVs in the group") def randidx(self, size=None): if size is None: @@ -284,24 +301,24 @@ def _new_initial(self, size, deterministic, more_replacements=None): else: return self.histogram[self.randidx(size)] - @property + @cached_property def symbolic_random(self): return self.symbolic_initial - @property + @cached_property def histogram(self): return self.params_dict["histogram"] - @node_property + @cached_property def mean(self): return self.histogram.mean(0) - @node_property + @cached_property def cov(self): x = self.histogram - self.mean return x.T.dot(x) / pm.floatX(self.histogram.shape[0]) - @node_property + @cached_property def std(self): return pt.sqrt(pt.diag(self.cov)) diff --git a/pymc/variational/operators.py b/pymc/variational/operators.py index 502fe13ab9..951f521d51 100644 --- a/pymc/variational/operators.py +++ b/pymc/variational/operators.py @@ -19,6 +19,7 @@ import pymc as pm +from pymc.model import modelcontext from pymc.variational import opvi from pymc.variational.opvi import ( NotImplementedInference, @@ -142,7 +143,8 @@ def __init__(self, approx, temperature=1): def apply(self, f): # f: kernel function for KSD f(histogram) -> (k(x,.), \nabla_x k(x,.)) - if _known_scan_ignored_inputs([self.approx.model.logp()]): + model = modelcontext(None) + if _known_scan_ignored_inputs([model.logp()]): raise NotImplementedInference( "SVGD does not currently support Minibatch or Simulator RV" ) diff --git a/pymc/variational/opvi.py b/pymc/variational/opvi.py index 3cd5cc3dcf..7d480b0fdc 100644 --- a/pymc/variational/opvi.py +++ b/pymc/variational/opvi.py @@ -52,6 +52,8 @@ import itertools import warnings +from dataclasses import dataclass +from functools import cached_property from typing import Any, overload import numpy as np @@ -70,7 +72,7 @@ from pymc.backends.ndarray import NDArray from pymc.blocking import DictToArrayBijection from pymc.initial_point import make_initial_point_fn -from pymc.model import modelcontext +from pymc.model import Model, modelcontext from pymc.pytensorf import ( SeedSequenceSeed, compile, @@ -78,13 +80,7 @@ find_rng_nodes, reseed_rngs, ) -from pymc.util import ( - RandomState, - WithMemoization, - _get_seeds_per_chain, - locally_cachedmethod, - makeiter, -) +from pymc.util import RandomState, _get_seeds_per_chain, makeiter, point_wrapper from pymc.variational.minibatch_rv import MinibatchRandomVariable, get_scaling from pymc.variational.updates import adagrad_window from pymc.vartypes import discrete_types @@ -143,21 +139,6 @@ def inner(*args, **kwargs): return wrap -def node_property(f): - """Wrap method to accessible tensor.""" - if isinstance(f, str): - - def wrapper(fn): - ff = append_name(f)(fn) - f_ = pytensor.config.change_flags(compute_test_value="off")(ff) - return property(locally_cachedmethod(f_)) - - return wrapper - else: - f_ = pytensor.config.change_flags(compute_test_value="off")(f) - return property(locally_cachedmethod(f_)) - - @pytensor.config.change_flags(compute_test_value="ignore") def try_to_set_test_value(node_in, node_out, s): _s = s @@ -497,7 +478,7 @@ def __init__(self, approx): varlogp_norm = 
property(lambda self: self.approx.varlogp_norm) datalogp_norm = property(lambda self: self.approx.datalogp_norm) logq_norm = property(lambda self: self.approx.logq_norm) - model = property(lambda self: self.approx.model) + model = property(lambda self: modelcontext(None)) def apply(self, f): # pragma: no cover R"""Operator itself. @@ -587,7 +568,7 @@ def from_function(cls, f): return obj -class Group(WithMemoization): +class Group: R"""**Base class for grouping variables in VI**. Grouped Approximation is used for modelling mutual dependencies @@ -770,7 +751,6 @@ def __init__( self._vfam = vfam self.rng = np.random.RandomState(random_seed) model = modelcontext(model) - self.model = model self.group = group self.user_params = params self._user_params = None @@ -783,17 +763,42 @@ def __init__( self.__init_group__(self.group) def _prepare_start(self, start=None): + model = modelcontext(None) + # If start is already an array, we need to ensure it's flattened and matches ddim + if isinstance(start, np.ndarray): + start_flat = start.flatten() + if start_flat.size != self.ddim: + raise ValueError( + f"Mismatch in start array size: got {start_flat.size}, expected {self.ddim}. " + f"Start array shape: {start.shape}, flattened size: {start_flat.size}" + ) + return start_flat + # Otherwise, get initial point from model and filter by group variables ipfn = make_initial_point_fn( - model=self.model, + model=model, overrides=start, jitter_rvs={}, return_transformed=True, ) start = ipfn(self.rng.randint(2**30, dtype=np.int64)) - group_vars = {self.model.rvs_to_values[v].name for v in self.group} + group_vars = {model.rvs_to_values[v].name for v in self.group} start = {k: v for k, v in start.items() if k in group_vars} - start = DictToArrayBijection.map(start).data - return start + if not start: + raise ValueError( + f"No matching variables found in initial point for group variables: {group_vars}. " + f"Initial point keys: {list(ipfn(self.rng.randint(2**30, dtype=np.int64)).keys())}" + ) + start_raveled = DictToArrayBijection.map(start) + # Ensure we have a 1D array that matches self.ddim + start_data = start_raveled.data + expected_size = self.ddim + if start_data.size != expected_size: + raise ValueError( + f"Mismatch in start array size: got {start_data.size}, expected {expected_size}. " + f"Group variables: {group_vars}, Start dict keys: {list(start.keys())}, " + f"This might indicate an issue with the model context or group initialization." 
+ ) + return start_data @classmethod def get_param_spec_for(cls, **kwargs): @@ -867,9 +872,11 @@ def __init_group__(self, group): """Initialize the group.""" if not group: raise GroupError("Got empty group") + model = modelcontext(None) + if self.group is None: - # delayed init - self.group = group + self.group = list(group) + self.symbolic_initial = self._initial_type( self.__class__.__name__ + "_symbolic_initial_tensor" ) @@ -878,15 +885,18 @@ def __init_group__(self, group): # so I have to to it by myself # 1) we need initial point (transformed space) - model_initial_point = self.model.initial_point(0) + model_initial_point = model.initial_point(0) # 2) we'll work with a single group, a subset of the model # here we need to create a mapping to replace value_vars with slices from the approximation + # Clear old replacements/ordering before rebuilding + self.replacements = collections.OrderedDict() + self.ordering = collections.OrderedDict() start_idx = 0 for var in self.group: if var.type.numpy_dtype.name in discrete_types: raise ParametrizationError(f"Discrete variables are not supported by VI: {var}") # 3) This is the way to infer shape and dtype of the variable - value_var = self.model.rvs_to_values[var] + value_var = model.rvs_to_values[var] test_var = model_initial_point[value_var.name] shape = test_var.shape size = test_var.size @@ -911,8 +921,9 @@ def params_dict(self): # prefixed are correctly reshaped if self._user_params is not None: return self._user_params - else: - return self.shared_params + if self.shared_params is None: + raise ParametrizationError("Group parameters have not been initialized") + return self.shared_params @property def params(self): @@ -940,14 +951,6 @@ def _new_initial_shape(self, size, dim, more_replacements=None): """ return pt.stack([size, dim]) - @node_property - def ndim(self): - return self.ddim - - @property - def ddim(self): - return sum(s.stop - s.start for _, s, _, _ in self.ordering.values()) - def _new_initial(self, size, deterministic, more_replacements=None): """*Dev* - allocates new initial random generator. @@ -993,7 +996,15 @@ def _new_initial(self, size, deterministic, more_replacements=None): initial = pt.switch(deterministic, pt.ones(shape, dtype) * dist_map, sample) return initial - @node_property + @property + def ndim(self): + return self.ddim + + @property + def ddim(self): + return sum(s.stop - s.start for _, s, _, _ in self.ordering.values()) + + @cached_property def symbolic_random(self): """*Dev* - abstract node that takes `self.symbolic_initial` and creates approximate posterior that is parametrized with `self.params_dict`. 
@@ -1100,7 +1111,7 @@ def make_size_and_deterministic_replacements(self, s, d, more_replacements=None) initial = graph_replace(initial, more_replacements, strict=False) return {self.symbolic_initial: initial} - @node_property + @cached_property def symbolic_normalizing_constant(self): """*Dev* - normalizing constant for `self.logq`, scales it to `minibatch_size` instead of `total_size`.""" t = self.to_flat_input( @@ -1119,45 +1130,37 @@ def symbolic_normalizing_constant(self): t = self.symbolic_single_sample(t) return pm.floatX(t) - @node_property + @property def symbolic_logq_not_scaled(self): """*Dev* - symbolically computed logq for `self.symbolic_random` computations can be more efficient since all is known beforehand including `self.symbolic_random`.""" raise NotImplementedError # shape (s,) - @node_property + @cached_property def symbolic_logq(self): """*Dev* - correctly scaled `self.symbolic_logq_not_scaled`.""" return self.symbolic_logq_not_scaled - @node_property + @cached_property def logq(self): """*Dev* - Monte Carlo estimate for group `logQ`.""" return self.symbolic_logq.mean(0) - @node_property + @cached_property def logq_norm(self): """*Dev* - Monte Carlo estimate for group `logQ` normalized.""" return self.logq / self.symbolic_normalizing_constant - def __str__(self): - """Return a string representation for the object.""" - if self.group is None: - shp = "undefined" - else: - shp = str(self.ddim) - return f"{self.__class__.__name__}[{shp}]" - - @node_property + @property def std(self) -> pt.TensorVariable: """Return the standard deviation of the latent variables as an unstructured 1-dimensional tensor variable.""" raise NotImplementedError() - @node_property + @property def cov(self) -> pt.TensorVariable: """Return the covariance between the latent variables as an unstructured 2-dimensional tensor variable.""" raise NotImplementedError() - @node_property + @property def mean(self) -> pt.TensorVariable: """Return the mean of the latent variables as an unstructured 1-dimensional tensor variable.""" raise NotImplementedError() @@ -1166,12 +1169,13 @@ def var_to_data(self, shared: pt.TensorVariable) -> xarray.Dataset: """Take a flat 1-dimensional tensor variable and maps it to an xarray data set based on the information in `self.ordering`.""" # This is somewhat similar to `DictToArrayBijection.rmap`, which doesn't work here since we don't have # `RaveledVars` and need to take the information from `self.ordering` instead + model = modelcontext(None) shared_nda = shared.eval() result = {} for name, s, shape, dtype in self.ordering.values(): - dims = self.model.named_vars_to_dims.get(name, None) + dims = model.named_vars_to_dims.get(name, None) if dims is not None: - coords = {d: np.array(self.model.coords[d]) for d in dims} + coords = {d: np.array(model.coords[d]) for d in dims} else: coords = None values = shared_nda[s].reshape(shape).astype(dtype) @@ -1193,7 +1197,13 @@ def std_data(self) -> xarray.Dataset: group_for_short_name = Group.group_for_short_name -class Approximation(WithMemoization): +@dataclass +class TraceSpec: + sample_vars: list + test_point: collections.OrderedDict + + +class Approximation: """**Wrapper for grouped approximations**. 
Wraps list of groups, creates an Approximation instance that collects @@ -1219,6 +1229,10 @@ class Approximation(WithMemoization): :class:`Group` """ + def __setstate__(self, state): + """Restore state after unpickling.""" + self.__dict__.update(state) + def __init__(self, groups, model=None): self._scale_cost_to_minibatch = pytensor.shared(np.int8(1)) model = modelcontext(model) @@ -1227,34 +1241,165 @@ def __init__(self, groups, model=None): self.groups = [] seen = set() rest = None - for g in groups: - if g.group is None: - if rest is not None: - raise GroupError("More than one group is specified for the rest variables") - else: + with model: + for g in groups: + if g.group is None: + if rest is not None: + raise GroupError("More than one group is specified for the rest variables") rest = g - else: - if set(g.group) & seen: - raise GroupError("Found duplicates in groups") - seen.update(g.group) - self.groups.append(g) - # List iteration to preserve order for reproducibility between runs - unseen_free_RVs = [var for var in model.free_RVs if var not in seen] - if unseen_free_RVs: - if rest is None: - raise GroupError("No approximation is specified for the rest variables") - else: + else: + group_vars = list(g.group) + missing = [var for var in group_vars if var not in model.free_RVs] + if missing: + names = ", ".join(var.name for var in missing) + raise GroupError(f"Variables [{names}] are not part of the provided model") + if set(group_vars) & seen: + raise GroupError("Found duplicates in groups") + seen.update(group_vars) + self.groups.append(g) + # List iteration to preserve order for reproducibility between runs + unseen_free_RVs = [var for var in model.free_RVs if var not in seen] + if unseen_free_RVs: + if rest is None: + raise GroupError("No approximation is specified for the rest variables") rest.__init_group__(unseen_free_RVs) self.groups.append(rest) - self.model = model @property def has_logq(self): return all(self.collect("has_logq")) + @property + def model(self): + warnings.warn( + "`model` field is deprecated and will be removed in future versions. 
Use " + "a model context instead.", + DeprecationWarning, + ) + return modelcontext(None) + def collect(self, item): return [getattr(g, item) for g in self.groups] + def _variational_orderings(self, model): + orderings = collections.OrderedDict() + for g in self.groups: + orderings.update(g.ordering) + return orderings + + def _draw_variational_samples(self, model, names, draws, size_sym, random_seed): + with model: + if not names: + return {} + tensors = [self.rslice(name, model) for name in names] + tensors = self.set_size_and_deterministic(tensors, size_sym, 0) + sample_fn = compile([size_sym], tensors) + rng_nodes = find_rng_nodes(tensors) + if random_seed is not None: + reseed_rngs(rng_nodes, random_seed) + outputs = sample_fn(draws) + if not isinstance(outputs, list | tuple): + outputs = [outputs] + return dict(zip(names, outputs)) + + def _draw_forward_samples(self, model, approx_samples, approx_names, draws, random_seed): + from pymc.sampling.forward import compile_forward_sampling_function + + with model: + model_names = {model.rvs_to_values[v].name: v for v in model.free_RVs} + forward_names = sorted(name for name in model_names if name not in approx_names) + if not forward_names: + return {} + + forward_vars = [model_names[name] for name in forward_names] + approx_vars = [model_names[name] for name in approx_names if name in model_names] + sampler_fn, _ = compile_forward_sampling_function( + outputs=forward_vars, + vars_in_trace=approx_vars, + basic_rvs=model.basic_RVs, + givens_dict=None, + random_seed=random_seed, + ) + approx_value_vars = [model.rvs_to_values[var] for var in approx_vars] + input_values = {var.name: approx_samples[var.name] for var in approx_value_vars} + wrapped_sampler = point_wrapper(sampler_fn) + + stacked = {name: [] for name in forward_names} + for i in range(draws): + inputs = {name: values[i] for name, values in input_values.items()} + raw = wrapped_sampler(**inputs) + if not isinstance(raw, list | tuple): + raw = [raw] + for name, value in zip(forward_names, raw): + stacked[name].append(value) + return {name: np.stack(values) for name, values in stacked.items()} + + def _collect_sample_vars(self, model, sample_names): + lookup = {} + for var in model.unobserved_value_vars: + lookup.setdefault(var.name, var) + for name, var in model.named_vars.items(): + lookup.setdefault(name, var) + sample_vars = [lookup[name] for name in sample_names if name in lookup] + seen = {var.name for var in sample_vars} + for var in model.unobserved_value_vars: + if var.name not in seen: + sample_vars.append(var) + return sample_vars, lookup + + def _compute_missing_trace_values(self, model, samples, missing_vars): + with model: + if not missing_vars: + return {} + input_vars = model.value_vars + base_point = model.initial_point() + point = { + var.name: np.asarray(samples[var.name][0]) + if var.name in samples + else base_point[var.name] + for var in input_vars + if var.name in samples or var.name in base_point + } + compute_fn = model.compile_fn( + missing_vars, + inputs=input_vars, + on_unused_input="ignore", + point_fn=True, + ) + raw_values = compute_fn(point) + if not isinstance(raw_values, list | tuple): + raw_values = [raw_values] + values = {var.name: np.asarray(value) for var, value in zip(missing_vars, raw_values)} + return values + + def _build_trace_spec(self, model, samples): + sample_names = sorted(samples.keys()) + sample_vars, _ = self._collect_sample_vars(model, sample_names) + initial_point = model.initial_point() + test_point = collections.OrderedDict() + 
missing_vars = [] + + for var in sample_vars: + trace_name = var.name + if trace_name in samples: + first_sample = np.asarray(samples[trace_name][0]) + test_point[trace_name] = first_sample + continue + if trace_name in initial_point: + value = np.asarray(initial_point[trace_name]) + test_point[trace_name] = value + continue + missing_vars.append(var) + + values = self._compute_missing_trace_values(model, samples, missing_vars) + for name, value in values.items(): + test_point[name] = value + + return TraceSpec( + sample_vars=sample_vars, + test_point=test_point, + ) + inputs = property(lambda self: self.collect("input")) symbolic_randoms = property(lambda self: self.collect("symbolic_random")) @@ -1267,12 +1412,13 @@ def scale_cost_to_minibatch(self): def scale_cost_to_minibatch(self, value): self._scale_cost_to_minibatch.set_value(np.int8(bool(value))) - @node_property + @cached_property def symbolic_normalizing_constant(self): """*Dev* - normalizing constant for `self.logq`, scales it to `minibatch_size` instead of `total_size`. Here the effect is controlled by `self.scale_cost_to_minibatch`. """ + model = modelcontext(None) t = pt.max( self.collect("symbolic_normalizing_constant") + [ @@ -1280,98 +1426,98 @@ def symbolic_normalizing_constant(self): obs.owner.inputs[1:], constant_fold([obs.owner.inputs[0].shape], raise_not_constant=False), ) - for obs in self.model.observed_RVs + for obs in model.observed_RVs if isinstance(obs.owner.op, MinibatchRandomVariable) ] ) t = pt.switch(self._scale_cost_to_minibatch, t, pt.constant(1, dtype=t.dtype)) return pm.floatX(t) - @node_property + @cached_property def symbolic_logq(self): """*Dev* - collects `symbolic_logq` for all groups.""" return pt.add(*self.collect("symbolic_logq")) - @node_property + @cached_property def logq(self): """*Dev* - collects `logQ` for all groups.""" return pt.add(*self.collect("logq")) - @node_property + @cached_property def logq_norm(self): """*Dev* - collects `logQ` for all groups and normalizes it.""" return self.logq / self.symbolic_normalizing_constant - @node_property + @cached_property def _sized_symbolic_varlogp_and_datalogp(self): """*Dev* - computes sampled prior term from model via `pytensor.scan`.""" - varlogp_s, datalogp_s = self.symbolic_sample_over_posterior( - [self.model.varlogp, self.model.datalogp] - ) + model = modelcontext(None) + varlogp_s, datalogp_s = self.symbolic_sample_over_posterior([model.varlogp, model.datalogp]) return varlogp_s, datalogp_s # both shape (s,) - @node_property + @cached_property def sized_symbolic_varlogp(self): """*Dev* - computes sampled prior term from model via `pytensor.scan`.""" return self._sized_symbolic_varlogp_and_datalogp[0] # shape (s,) - @node_property + @cached_property def sized_symbolic_datalogp(self): """*Dev* - computes sampled data term from model via `pytensor.scan`.""" return self._sized_symbolic_varlogp_and_datalogp[1] # shape (s,) - @node_property + @cached_property def sized_symbolic_logp(self): """*Dev* - computes sampled logP from model via `pytensor.scan`.""" return self.sized_symbolic_varlogp + self.sized_symbolic_datalogp # shape (s,) - @node_property + @cached_property def logp(self): """*Dev* - computes :math:`E_{q}(logP)` from model via `pytensor.scan` that can be optimized later.""" return self.varlogp + self.datalogp - @node_property + @cached_property def varlogp(self): """*Dev* - computes :math:`E_{q}(prior term)` from model via `pytensor.scan` that can be optimized later.""" return self.sized_symbolic_varlogp.mean(0) - @node_property + 
@cached_property def datalogp(self): """*Dev* - computes :math:`E_{q}(data term)` from model via `pytensor.scan` that can be optimized later.""" return self.sized_symbolic_datalogp.mean(0) - @node_property + @cached_property def _single_symbolic_varlogp_and_datalogp(self): """*Dev* - computes sampled prior term from model via `pytensor.scan`.""" - varlogp, datalogp = self.symbolic_single_sample([self.model.varlogp, self.model.datalogp]) + model = modelcontext(None) + varlogp, datalogp = self.symbolic_single_sample([model.varlogp, model.datalogp]) return varlogp, datalogp - @node_property + @cached_property def single_symbolic_varlogp(self): """*Dev* - for single MC sample estimate of :math:`E_{q}(prior term)` `pytensor.scan` is not needed and code can be optimized.""" return self._single_symbolic_varlogp_and_datalogp[0] - @node_property + @cached_property def single_symbolic_datalogp(self): """*Dev* - for single MC sample estimate of :math:`E_{q}(data term)` `pytensor.scan` is not needed and code can be optimized.""" return self._single_symbolic_varlogp_and_datalogp[1] - @node_property + @cached_property def single_symbolic_logp(self): """*Dev* - for single MC sample estimate of :math:`E_{q}(logP)` `pytensor.scan` is not needed and code can be optimized.""" return self.single_symbolic_datalogp + self.single_symbolic_varlogp - @node_property + @cached_property def logp_norm(self): """*Dev* - normalized :math:`E_{q}(logP)`.""" return self.logp / self.symbolic_normalizing_constant - @node_property + @cached_property def varlogp_norm(self): """*Dev* - normalized :math:`E_{q}(prior term)`.""" return self.varlogp / self.symbolic_normalizing_constant - @node_property + @cached_property def datalogp_norm(self): """*Dev* - normalized :math:`E_{q}(data term)`.""" return self.datalogp / self.symbolic_normalizing_constant @@ -1482,7 +1628,7 @@ def get_optimization_replacements(self, s, d): return repl @pytensor.config.change_flags(compute_test_value="off") - def sample_node(self, node, size=None, deterministic=False, more_replacements=None): + def sample_node(self, node, size=None, deterministic=False, more_replacements=None, model=None): """Sample given node or nodes over shared posterior. 
Parameters @@ -1501,62 +1647,69 @@ def sample_node(self, node, size=None, deterministic=False, more_replacements=No sampled node(s) with replacements """ node_in = node - if more_replacements: - node = graph_replace(node, more_replacements, strict=False) - if not isinstance(node, list | tuple): - node = [node] - node = self.model.replace_rvs_by_values(node) - if not isinstance(node_in, list | tuple): - node = node[0] - if size is None: - node_out = self.symbolic_single_sample(node) - else: - node_out = self.symbolic_sample_over_posterior(node) - node_out = self.set_size_and_deterministic(node_out, size, deterministic) - try_to_set_test_value(node_in, node_out, size) - return node_out - def rslice(self, name): + model = modelcontext(model) + with model: + if more_replacements: + node = graph_replace(node, more_replacements, strict=False) + if not isinstance(node, list | tuple): + node = [node] + node = model.replace_rvs_by_values(node) + if not isinstance(node_in, list | tuple): + node = node[0] + if size is None: + node_out = self.symbolic_single_sample(node) + else: + node_out = self.symbolic_sample_over_posterior(node) + node_out = self.set_size_and_deterministic(node_out, size, deterministic) + try_to_set_test_value(node_in, node_out, size) + return node_out + + def rslice(self, name, model=None): """*Dev* - vectorized sampling for named random variable without call to `pytensor.scan`. This node still needs :func:`set_size_and_deterministic` to be evaluated. """ + model = modelcontext(model) - def vars_names(vs): - return {self.model.rvs_to_values[v].name for v in vs} - - for vars_, random, ordering in zip( - self.collect("group"), self.symbolic_randoms, self.collect("ordering") - ): - if name in vars_names(vars_): - name_, slc, shape, dtype = ordering[name] - found = random[..., slc].reshape((random.shape[0], *shape)).astype(dtype) - found.name = name + "_vi_random_slice" - break - else: - raise KeyError(f"{name!r} not found") + with model: + for random, ordering in zip(self.symbolic_randoms, self.collect("ordering")): + if name in ordering: + _name, slc, shape, dtype = ordering[name] + found = random[..., slc].reshape((random.shape[0], *shape)).astype(dtype) + found.name = name + "_vi_random_slice" + break + else: + raise KeyError(f"{name!r} not found") return found - @node_property + @property def sample_dict_fn(self): s = pt.iscalar() - names = [self.model.rvs_to_values[v].name for v in self.model.free_RVs] - sampled = [self.rslice(name) for name in names] - sampled = self.set_size_and_deterministic(sampled, s, 0) - sample_fn = compile([s], sampled) - rng_nodes = find_rng_nodes(sampled) - def inner(draws=100, *, random_seed: SeedSequenceSeed = None): - if random_seed is not None: - reseed_rngs(rng_nodes, random_seed) - _samples = sample_fn(draws) - - return dict(zip(names, _samples)) + def inner(draws=100, *, model=None, random_seed: SeedSequenceSeed = None): + model = modelcontext(model) + with model: + orderings = self._variational_orderings(model) + approx_var_names = sorted(orderings.keys()) + approx_samples = self._draw_variational_samples( + model, approx_var_names, draws, s, random_seed + ) + forward_samples = self._draw_forward_samples( + model, approx_samples, approx_var_names, draws, random_seed + ) + return {**approx_samples, **forward_samples} return inner def sample( - self, draws=500, *, random_seed: RandomState = None, return_inferencedata=True, **kwargs + self, + draws=500, + *, + model: Model | None = None, + random_seed: RandomState = None, + 
return_inferencedata=True, + **kwargs, ): """Draw samples from variational posterior. @@ -1564,6 +1717,8 @@ def sample( ---------- draws : int Number of random samples. + model : Model (optional if in ``with`` context + Model to be used to generate samples. random_seed : int, RandomState or Generator, optional Seed for the random number generator. return_inferencedata : bool @@ -1574,33 +1729,48 @@ def sample( trace: :class:`pymc.backends.base.MultiTrace` Samples drawn from variational posterior. """ - # TODO: add tests for include_transformed case kwargs["log_likelihood"] = False - if random_seed is not None: - (random_seed,) = _get_seeds_per_chain(random_seed, 1) - samples: dict = self.sample_dict_fn(draws, random_seed=random_seed) - points = ( - {name: np.asarray(records[i]) for name, records in samples.items()} - for i in range(draws) - ) + model = modelcontext(model) - trace = NDArray( - model=self.model, - test_point={name: records[0] for name, records in samples.items()}, - ) - try: - trace.setup(draws=draws, chain=0) - for point in points: - trace.record(point) - finally: - trace.close() + with model: + if random_seed is not None: + (random_seed,) = _get_seeds_per_chain(random_seed, 1) + samples: dict = self.sample_dict_fn(draws, model=model, random_seed=random_seed) + spec = self._build_trace_spec(model, samples) + + from collections import OrderedDict + + default_point = model.initial_point() + value_var_names = [var.name for var in model.value_vars] + points = ( + OrderedDict( + ( + name, + np.asarray(samples[name][i]) + if name in samples and len(samples[name]) > i + else np.asarray(spec.test_point.get(name, default_point[name])), + ) + for name in value_var_names + ) + for i in range(draws) + ) + + trace = NDArray( + model=model, + ) + try: + trace.setup(draws=draws, chain=0) + for point in points: + trace.record(point) + finally: + trace.close() multi_trace = MultiTrace([trace]) if not return_inferencedata: return multi_trace else: - return pm.to_inference_data(multi_trace, model=self.model, **kwargs) + return pm.to_inference_data(multi_trace, model=model, **kwargs) @property def ndim(self): @@ -1610,7 +1780,7 @@ def ndim(self): def ddim(self): return sum(self.collect("ddim")) - @node_property + @cached_property def symbolic_random(self): return pt.concatenate(self.collect("symbolic_random"), axis=-1) @@ -1630,7 +1800,7 @@ def all_histograms(self): def any_histograms(self): return any(isinstance(g, pm.approximations.EmpiricalGroup) for g in self.groups) - @node_property + @property def joint_histogram(self): if not self.all_histograms: raise VariationalInferenceError("%s does not consist of all Empirical approximations") diff --git a/pymc/variational/stein.py b/pymc/variational/stein.py index 0534bb6fa4..1bc9360c0f 100644 --- a/pymc/variational/stein.py +++ b/pymc/variational/stein.py @@ -12,13 +12,14 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+from functools import cached_property + import pytensor.tensor as pt from pytensor.graph.replace import graph_replace from pymc.pytensorf import floatX from pymc.util import WithMemoization, locally_cachedmethod -from pymc.variational.opvi import node_property from pymc.variational.test_functions import rbf __all__ = ["Stein"] @@ -38,14 +39,14 @@ def input_joint_matrix(self): else: return self.approx.symbolic_random - @node_property + @cached_property def approx_symbolic_matrices(self): if self.use_histogram: return self.approx.collect("histogram") else: return self.approx.symbolic_randoms - @node_property + @cached_property def dlogp(self): logp = self.logp_norm.sum() grad = pt.grad(logp, self.approx_symbolic_matrices) @@ -55,34 +56,34 @@ def flatten2(tensor): return pt.concatenate(list(map(flatten2, grad)), -1) - @node_property + @cached_property def grad(self): n = floatX(self.input_joint_matrix.shape[0]) temperature = self.temperature svgd_grad = self.density_part_grad / temperature + self.repulsive_part_grad return svgd_grad / n - @node_property + @cached_property def density_part_grad(self): Kxy = self.Kxy dlogpdx = self.dlogp return pt.dot(Kxy, dlogpdx) - @node_property + @cached_property def repulsive_part_grad(self): t = self.approx.symbolic_normalizing_constant dxkxy = self.dxkxy return dxkxy / t - @property + @cached_property def Kxy(self): return self._kernel()[0] - @property + @cached_property def dxkxy(self): return self._kernel()[1] - @node_property + @cached_property def logp_norm(self): sized_symbolic_logp = self.approx.sized_symbolic_logp if self.use_histogram: diff --git a/tests/variational/test_approximations.py b/tests/variational/test_approximations.py index ab30e9bbe3..1bb983ba6b 100644 --- a/tests/variational/test_approximations.py +++ b/tests/variational/test_approximations.py @@ -55,8 +55,9 @@ def test_elbo(): # Create variational gradient tensor mean_field = MeanField(model=model) - with pytensor.config.change_flags(compute_test_value="off"): - elbo = -pm.operators.KL(mean_field)()(10000) + with model: + with pytensor.config.change_flags(compute_test_value="off"): + elbo = -pm.operators.KL(mean_field)()(10000) mean_field.shared_params["mu"].set_value(post_mu) mean_field.shared_params["rho"].set_value(np.log(np.exp(post_sigma) - 1)) @@ -113,9 +114,8 @@ def test_scale_cost_to_minibatch_works(aux_total_size): assert not mean_field_2.scale_cost_to_minibatch mean_field_2.shared_params["mu"].set_value(post_mu) mean_field_2.shared_params["rho"].set_value(np.log(np.exp(post_sigma) - 1)) - - with pytensor.config.change_flags(compute_test_value="off"): - elbo_via_total_size_unscaled = -pm.operators.KL(mean_field_2)()(10000) + with pytensor.config.change_flags(compute_test_value="off"): + elbo_via_total_size_unscaled = -pm.operators.KL(mean_field_2)()(10000) np.testing.assert_allclose( elbo_via_total_size_unscaled.eval(), diff --git a/tests/variational/test_inference.py b/tests/variational/test_inference.py index 10b824179d..ee442e30ab 100644 --- a/tests/variational/test_inference.py +++ b/tests/variational/test_inference.py @@ -41,7 +41,7 @@ def test_fit_with_nans(score): mean = inp * coef pm.Normal("y", mean, 0.1, observed=y) with pytest.raises(FloatingPointError) as e: - advi = pm.fit(100, score=score, obj_optimizer=pm.adam(learning_rate=float("nan"))) + pm.fit(100, score=score, obj_optimizer=pm.adam(learning_rate=float("nan"))) @pytest.fixture(scope="module", params=[True, False], ids=["mini", "full"]) @@ -174,8 +174,9 @@ def fit_kwargs(inference, use_minibatch): 
return _select[(type(inference), key)] -def test_fit_oo(inference, fit_kwargs, simple_model_data): - trace = inference.fit(**fit_kwargs).sample(10000) +def test_fit_oo(simple_model, inference, fit_kwargs, simple_model_data): + with simple_model: + trace = inference.fit(**fit_kwargs).sample(10000) mu_post = simple_model_data["mu_post"] d = simple_model_data["d"] np.testing.assert_allclose(np.mean(trace.posterior["mu"]), mu_post, rtol=0.05) @@ -202,7 +203,8 @@ def test_fit_start(inference_spec, simple_model): inference = inference_spec(**kw) try: - trace = inference.fit(n=0).sample(10000) + with simple_model: + trace = inference.fit(n=0).sample(10000) except NotImplementedInference as e: pytest.skip(str(e)) @@ -243,10 +245,11 @@ def test_fit_fn_text(method, kwargs, error): pm.fit(10, method=method, **kwargs) -def test_profile(inference): +def test_profile(inference, simple_model): if type(inference) in {SVGD, ASVGD}: pytest.skip("Not Implemented Inference") - inference.run_profiling(n=100).summary() + with simple_model: + inference.run_profiling(n=100).summary() @pytest.fixture(scope="module") @@ -266,64 +269,66 @@ def binomial_model_inference(binomial_model, inference_spec): @pytest.mark.xfail("pytensor.config.warn_float64 == 'raise'", reason="too strict float32") -def test_replacements(binomial_model_inference): +def test_replacements(binomial_model_inference, binomial_model): d = pytensor.shared(1) approx = binomial_model_inference.approx - p = approx.model.p - p_t = p**3 - p_s = approx.sample_node(p_t) - assert not any( - isinstance(n.owner.op, pytensor.tensor.random.basic.BetaRV) - for n in pytensor.graph.ancestors([p_s]) - if n.owner - ), "p should be replaced" - if pytensor.config.compute_test_value != "off": - assert p_s.tag.test_value.shape == p_t.tag.test_value.shape - sampled = [pm.draw(p_s) for _ in range(100)] - assert any(map(operator.ne, sampled[1:], sampled[:-1])) # stochastic - p_z = approx.sample_node(p_t, deterministic=False, size=10) - assert p_z.shape.eval() == (10,) - try: - p_z = approx.sample_node(p_t, deterministic=True, size=10) + with binomial_model: + p = binomial_model.p + p_t = p**3 + p_s = approx.sample_node(p_t) + assert not any( + isinstance(n.owner.op, pytensor.tensor.random.basic.BetaRV) + for n in pytensor.graph.ancestors([p_s]) + if n.owner + ), "p should be replaced" + if pytensor.config.compute_test_value != "off": + assert p_s.tag.test_value.shape == p_t.tag.test_value.shape + sampled = [pm.draw(p_s) for _ in range(100)] + assert any(map(operator.ne, sampled[1:], sampled[:-1])) # stochastic + p_z = approx.sample_node(p_t, deterministic=False, size=10) assert p_z.shape.eval() == (10,) - except opvi.NotImplementedInference: - pass - - try: - p_d = approx.sample_node(p_t, deterministic=True) - sampled = [pm.draw(p_d) for _ in range(100)] + try: + p_z = approx.sample_node(p_t, deterministic=True, size=10) + assert p_z.shape.eval() == (10,) + except opvi.NotImplementedInference: + pass + + try: + p_d = approx.sample_node(p_t, deterministic=True) + sampled = [pm.draw(p_d) for _ in range(100)] + assert all(map(operator.eq, sampled[1:], sampled[:-1])) # deterministic + except opvi.NotImplementedInference: + pass + + p_r = approx.sample_node(p_t, deterministic=d) + d.set_value(1) + sampled = [pm.draw(p_r) for _ in range(100)] assert all(map(operator.eq, sampled[1:], sampled[:-1])) # deterministic - except opvi.NotImplementedInference: - pass - - p_r = approx.sample_node(p_t, deterministic=d) - d.set_value(1) - sampled = [pm.draw(p_r) for _ in range(100)] - 
assert all(map(operator.eq, sampled[1:], sampled[:-1])) # deterministic - d.set_value(0) - sampled = [pm.draw(p_r) for _ in range(100)] - assert any(map(operator.ne, sampled[1:], sampled[:-1])) # stochastic + d.set_value(0) + sampled = [pm.draw(p_r) for _ in range(100)] + assert any(map(operator.ne, sampled[1:], sampled[:-1])) # stochastic -def test_sample_replacements(binomial_model_inference): +def test_sample_replacements(binomial_model_inference, binomial_model): i = pt.iscalar() i.tag.test_value = 1 approx = binomial_model_inference.approx - p = approx.model.p - p_t = p**3 - p_s = approx.sample_node(p_t, size=100) - if pytensor.config.compute_test_value != "off": - assert p_s.tag.test_value.shape == (100, *p_t.tag.test_value.shape) - sampled = p_s.eval() - assert any(map(operator.ne, sampled[1:], sampled[:-1])) # stochastic - assert sampled.shape[0] == 100 - - p_d = approx.sample_node(p_t, size=i) - sampled = p_d.eval({i: 100}) - assert any(map(operator.ne, sampled[1:], sampled[:-1])) # deterministic - assert sampled.shape[0] == 100 - sampled = p_d.eval({i: 101}) - assert sampled.shape[0] == 101 + with binomial_model: + p = binomial_model.p + p_t = p**3 + p_s = approx.sample_node(p_t, size=100) + if pytensor.config.compute_test_value != "off": + assert p_s.tag.test_value.shape == (100, *p_t.tag.test_value.shape) + sampled = p_s.eval() + assert any(map(operator.ne, sampled[1:], sampled[:-1])) # stochastic + assert sampled.shape[0] == 100 + + p_d = approx.sample_node(p_t, size=i) + sampled = p_d.eval({i: 100}) + assert any(map(operator.ne, sampled[1:], sampled[:-1])) # deterministic + assert sampled.shape[0] == 100 + sampled = p_d.eval({i: 101}) + assert sampled.shape[0] == 101 def test_remove_scan_op(): @@ -353,30 +358,28 @@ def test_var_replacement(): assert advi.sample_node(mean, more_replacements={inp: x_new}).eval().shape == (11,) -def test_clear_cache(): - with pm.Model(): +@pytest.mark.parametrize( + "inference_cls", + [ADVI, FullRankADVI], +) +def test_advi_pickle(inference_cls): + with pm.Model() as model: pm.Normal("n", 0, 1) - inference = ADVI() + inference = inference_cls() inference.fit(n=10) - assert any(len(c) != 0 for c in inference.approx._cache.values()) - inference.approx._cache.clear() - # should not be cleared at this call - assert all(len(c) == 0 for c in inference.approx._cache.values()) - new_a = cloudpickle.loads(cloudpickle.dumps(inference.approx)) - assert not hasattr(new_a, "_cache") - inference_new = pm.KLqp(new_a) + serialized = cloudpickle.dumps(inference.approx) + new_approx = cloudpickle.loads(serialized) + inference_new = pm.KLqp(new_approx) inference_new.fit(n=10) - assert any(len(c) != 0 for c in inference_new.approx._cache.values()) - inference_new.approx._cache.clear() - assert all(len(c) == 0 for c in inference_new.approx._cache.values()) -def test_fit_data(inference, fit_kwargs, simple_model_data): - fitted = inference.fit(**fit_kwargs) - mu_post = simple_model_data["mu_post"] - d = simple_model_data["d"] - np.testing.assert_allclose(fitted.mean_data["mu"].values, mu_post, rtol=0.05) - np.testing.assert_allclose(fitted.std_data["mu"], np.sqrt(1.0 / d), rtol=0.2) +def test_fit_data(inference, fit_kwargs, simple_model_data, simple_model): + with simple_model: + fitted = inference.fit(**fit_kwargs) + mu_post = simple_model_data["mu_post"] + d = simple_model_data["d"] + np.testing.assert_allclose(fitted.mean_data["mu"].values, mu_post, rtol=0.05) + np.testing.assert_allclose(fitted.std_data["mu"], np.sqrt(1.0 / d), rtol=0.2) @pytest.fixture @@ -440,13 
+443,13 @@ def test_fit_data_coords(hierarchical_model, hierarchical_model_data): with hierarchical_model: fitted = pm.fit(1) - for data in [fitted.mean_data, fitted.std_data]: - assert set(data.keys()) == {"sigma_group_mu_log__", "sigma_log__", "group_mu", "mu"} - assert data["group_mu"].shape == hierarchical_model_data["group_shape"] - assert list(data["group_mu"].coords.keys()) == list( - hierarchical_model_data["group_coords"].keys() - ) - assert data["mu"].shape == () + for data in [fitted.mean_data, fitted.std_data]: + assert set(data.keys()) == {"sigma_group_mu_log__", "sigma_log__", "group_mu", "mu"} + assert data["group_mu"].shape == hierarchical_model_data["group_shape"] + assert list(data["group_mu"].coords.keys()) == list( + hierarchical_model_data["group_coords"].keys() + ) + assert data["mu"].shape == () def test_multiple_minibatch_variables(): diff --git a/tests/variational/test_opvi.py b/tests/variational/test_opvi.py index 0f40572f72..6ecb68de17 100644 --- a/tests/variational/test_opvi.py +++ b/tests/variational/test_opvi.py @@ -184,48 +184,55 @@ def test_init_groups(three_var_model, raises, grouping): ids=lambda t: ", ".join(f"{k.__name__}: {v[0]}" for k, v in t[1].items()), ) def three_var_groups(request, three_var_model): - kw, grouping = request.param - approxes, groups = zip(*grouping.items()) - groups, gkwargs = zip(*groups) - groups = [ - list(map(ft.partial(getattr, three_var_model), g)) if g is not None else None - for g in groups - ] - inited_groups = [ - a(group=g, model=three_var_model, **gk) for a, g, gk in zip(approxes, groups, gkwargs) - ] + with three_var_model: + kw, grouping = request.param + approxes, groups = zip(*grouping.items()) + groups, gkwargs = zip(*groups) + groups = [ + list(map(ft.partial(getattr, three_var_model), g)) if g is not None else None + for g in groups + ] + inited_groups = [ + a(group=g, model=three_var_model, **gk) for a, g, gk in zip(approxes, groups, gkwargs) + ] return inited_groups @pytest.fixture def three_var_approx(three_var_model, three_var_groups): - approx = opvi.Approximation(three_var_groups, model=three_var_model) + with three_var_model: + approx = opvi.Approximation(three_var_groups, model=three_var_model) return approx @pytest.fixture def three_var_approx_single_group_mf(three_var_model): - return MeanField(model=three_var_model) + with three_var_model: + approx = MeanField(model=three_var_model) + return approx -def test_pickle_approx(three_var_approx): +def test_pickle_approx(three_var_approx, three_var_model): import cloudpickle dump = cloudpickle.dumps(three_var_approx) new = cloudpickle.loads(dump) - assert new.sample(1) + with three_var_model: + assert new.sample(1) -def test_pickle_single_group(three_var_approx_single_group_mf): +def test_pickle_single_group(three_var_approx_single_group_mf, three_var_model): import cloudpickle dump = cloudpickle.dumps(three_var_approx_single_group_mf) new = cloudpickle.loads(dump) - assert new.sample(1) + with three_var_model: + assert new.sample(1) -def test_sample_simple(three_var_approx): - trace = three_var_approx.sample(100, return_inferencedata=False) +def test_sample_simple(three_var_approx, three_var_model): + with three_var_model: + trace = three_var_approx.sample(100, return_inferencedata=False) assert set(trace.varnames) == {"one", "one_log__", "three", "two"} assert len(trace) == 100 assert trace[0]["one"].shape == (10, 2) @@ -246,39 +253,48 @@ def parametric_grouped_approxes(request): def test_logq_mini_1_sample_1_var(parametric_grouped_approxes, 
three_var_model): cls, kw = parametric_grouped_approxes - approx = cls([three_var_model.one], model=three_var_model, **kw) - logq = approx.logq - logq = approx.set_size_and_deterministic(logq, 1, 0) - logq.eval() + with three_var_model: + approx = cls([three_var_model.one], model=three_var_model, **kw) + logq = approx.logq + logq = approx.set_size_and_deterministic(logq, 1, 0) + logq.eval() def test_logq_mini_2_sample_2_var(parametric_grouped_approxes, three_var_model): cls, kw = parametric_grouped_approxes - approx = cls([three_var_model.one, three_var_model.two], model=three_var_model, **kw) - logq = approx.logq - logq = approx.set_size_and_deterministic(logq, 2, 0) - logq.eval() - - -def test_logq_globals(three_var_approx): - if not three_var_approx.has_logq: - pytest.skip(f"{three_var_approx} does not implement logq") - approx = three_var_approx - logq, symbolic_logq = approx.set_size_and_deterministic( - [approx.logq, approx.symbolic_logq], 1, 0 - ) - e = logq.eval() - es = symbolic_logq.eval() - assert e.shape == () - assert es.shape == (1,) - - logq, symbolic_logq = approx.set_size_and_deterministic( - [approx.logq, approx.symbolic_logq], 2, 0 - ) - e = logq.eval() - es = symbolic_logq.eval() - assert e.shape == () - assert es.shape == (2,) + with three_var_model: + approx = cls([three_var_model.one, three_var_model.two], model=three_var_model, **kw) + logq = approx.logq + logq = approx.set_size_and_deterministic(logq, 2, 0) + logq.eval() + + +def test_logq_globals(three_var_approx, three_var_model): + with three_var_model: + if not three_var_approx.has_logq: + pytest.skip(f"{three_var_approx} does not implement logq") + approx = three_var_approx + logq, symbolic_logq = approx.set_size_and_deterministic( + [approx.logq, approx.symbolic_logq], 1, 0 + ) + e = logq.eval() + es = symbolic_logq.eval() + assert e.shape == () + assert es.shape == (1,) + + logq, symbolic_logq = approx.set_size_and_deterministic( + [approx.logq, approx.symbolic_logq], 2, 0 + ) + e = logq.eval() + es = symbolic_logq.eval() + assert e.shape == () + assert es.shape == (2,) + + +def test_model_property_emits_deprecation(three_var_approx, three_var_model): + with three_var_model: + with pytest.warns(DeprecationWarning, match="`model` field is deprecated"): + _ = three_var_approx.model def test_symbolic_normalizing_constant_no_rvs(): @@ -292,5 +308,32 @@ def test_symbolic_normalizing_constant_no_rvs(): y_hat = pm.Flat("y_hat", observed=obs_batch, total_size=1000) step = pm.ADVI() - - assert_no_rvs(step.approx.symbolic_normalizing_constant) + # Access property within model context + symbolic_normalizing = step.approx.symbolic_normalizing_constant + + # Access the property again to test it doesn't require model context after first access + assert_no_rvs(symbolic_normalizing) + + +def test_sample_additional_vars(three_var_approx, three_var_model): + with pm.Model() as extended_model: + one = pm.HalfNormal("one", size=(10, 2)) + two = pm.Normal("two", size=(10,)) + three = pm.Normal("three", size=(10, 1, 2)) + four = pm.Normal("four", mu=two, sigma=1, size=(10,)) + five = pm.Deterministic("five", four.sum()) + pm.Normal("six", mu=five, sigma=1) + + with extended_model: + idata = three_var_approx.sample(20) + + posterior = idata.posterior + + varnames = set(posterior.data_vars) + assert {"one", "two", "three"}.issubset(varnames) + assert {"four", "five", "six"}.issubset(varnames) + assert posterior.sizes["draw"] == 20 + assert posterior.sizes["chain"] == 1 + assert posterior["four"].shape == (1, 20, 10) + assert 
posterior["five"].shape == (1, 20) + assert posterior["six"].shape == (1, 20)
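
To illustrate how this refactor is meant to be used downstream (this sketch is not part of the patch): `Approximation` no longer stores a model reference, so drawing posterior samples either happens inside the owning model's context or passes an explicit `model=` argument to `Approximation.sample`, and the legacy `approx.model` attribute survives only as a deprecated shim that resolves the model from the active context. The toy model and variable names below are invented for illustration only.

import warnings

import numpy as np
import pymc as pm

rng = np.random.default_rng(42)

with pm.Model() as toy_model:
    mu = pm.Normal("mu", 0.0, 1.0)
    pm.Normal("obs", mu, 1.0, observed=rng.normal(size=100))

    # Fitting is unchanged: it already runs inside the model context.
    approx = pm.fit(n=1000, method="advi", random_seed=1)

    # Sampling inside the ``with`` block picks the model up from the context.
    idata = approx.sample(draws=500, random_seed=1)

# Outside the context, the model is passed explicitly instead of being read
# from a stored ``approx.model`` attribute.
idata_explicit = approx.sample(draws=500, model=toy_model, random_seed=1)

# The old attribute is kept as a deprecated shim: accessing it warns and
# falls back to the model found in the current context.
with toy_model, warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    _ = approx.model
assert any(issubclass(w.category, DeprecationWarning) for w in caught)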