@@ -140,9 +140,23 @@ def handle_dims(x: pt.TensorLike, dims: Dims, desired_dims: Dims) -> pt.TensorVa
140140
141141 Doesn't check for validity of the dims
142142
143+ Parameters
144+ ----------
145+ x : pt.TensorLike
146+ The tensor to align.
147+ dims : Dims
148+ The current dimensions of the tensor.
149+ desired_dims : Dims
150+ The desired dimensions of the tensor.
151+
152+ Returns
153+ -------
154+ pt.TensorVariable
155+ The aligned tensor.
156+
143157 Examples
144158 --------
145- 1D to 2D with new dim
159+ Handle transpose 1D to 2D with new dimension.
146160
147161 .. code-block:: python
148162
@@ -179,10 +193,40 @@ def handle_dims(x: pt.TensorLike, dims: Dims, desired_dims: Dims) -> pt.TensorVa
179193
180194
181195DimHandler = Callable [[pt .TensorLike , Dims ], pt .TensorLike ]
196+ """A function that takes a tensor and its current dims and makes it broadcastable to the desired dims."""
182197
183198
184199def create_dim_handler (desired_dims : Dims ) -> DimHandler :
185- """Wrap the `handle_dims` function to act like the previous `create_dim_handler` function."""
200+ """Wrap the :func:`handle_dims` function to always use the same desired_dims.
201+
202+ Parameters
203+ ----------
204+ desired_dims : Dims
205+ The desired dimensions to align to.
206+
207+ Returns
208+ -------
209+ DimHandler
210+ A function that takes a tensor and its current dims and aligns it to
211+ the desired dims.
212+
213+
214+ Examples
215+ --------
216+ Create a dim handler to align to ("channel", "group").
217+
218+ .. code-block:: python
219+
220+ import numpy as np
221+
222+ from pymc_extras.prior import create_dim_handler
223+
224+ dim_handler = create_dim_handler(("channel", "group"))
225+
226+ result = dim_handler(np.array([1, 2, 3]), dims="channel")
227+
228+
229+ """
186230
187231 def func (x : pt .TensorLike , dims : Dims ) -> pt .TensorVariable :
188232 return handle_dims (x , dims , desired_dims )
@@ -272,9 +316,44 @@ def _get_pymc_parameters(distribution: pm.Distribution) -> set[str]:
272316
273317@runtime_checkable
274318class VariableFactory (Protocol ):
275- """Protocol for something that works like a Prior class."""
319+ '''Protocol for something that works like a Prior class.
320+
321+ Sample with :func:`sample_prior`.
322+
323+ Examples
324+ --------
325+ Create a custom variable factory.
326+
327+ .. code-block:: python
328+
329+ import pymc as pm
330+
331+ import pytensor.tensor as pt
332+
333+ from pymc_extras.prior import sample_prior, VariableFactory
334+
335+
336+ class PowerSumDistribution:
337+ """Create a distribution that is the sum of powers of a base distribution."""
338+ def __init__(self, distribution: VariableFactory, n: int):
339+ self.distribution = distribution
340+ self.n = n
341+
342+ @property
343+ def dims(self):
344+ return self.distribution.dims
345+
346+ def create_variable(self, name: str) -> "TensorVariable":
347+ raw = self.distribution.create_variable(f"{name}_raw")
348+ return pm.Deterministic(name, pt.sum([raw ** n for n in range(1, self.n + 1)], axis=0), dims=self.dims,)
349+
350+ cubic = PowerSumDistribution(Prior("Normal"), n=3)
351+ samples = sample_prior(cubic)
352+
353+ '''
276354
277355 dims : tuple [str , ...]
356+ """The dimensions of the variable to create."""
278357
279358 def create_variable (self , name : str ) -> pt .TensorVariable :
280359 """Create a TensorVariable."""
@@ -387,6 +466,80 @@ class Prior:
387466 be registered with `register_tensor_transform` function or
388467 be available in either `pytensor.tensor` or `pymc.math`.
389468
469+ Examples
470+ --------
471+ Create a normal prior.
472+
473+ .. code-block:: python
474+
475+ from pymc_extras.prior import Prior
476+
477+ normal = Prior("Normal")
478+
479+ Create a hierarchical normal prior by using distributions for the parameters
480+ and specifying the dims.
481+
482+ .. code-block:: python
483+
484+ hierarchical_normal = Prior(
485+ "Normal",
486+ mu=Prior("Normal"),
487+ sigma=Prior("HalfNormal"),
488+ dims="channel",
489+ )
490+
491+ Create a non-centered hierarchical normal prior with the `centered` parameter.
492+
493+ .. code-block:: python
494+
495+ non_centered_hierarchical_normal = Prior(
496+ "Normal",
497+ mu=Prior("Normal"),
498+ sigma=Prior("HalfNormal"),
499+ dims="channel",
500+ # Only change needed to make it non-centered
501+ centered=False,
502+ )
503+
504+ Create a hierarchical beta prior by using Beta distribution, distributions for
505+ the parameters, and specifying the dims.
506+
507+ .. code-block:: python
508+
509+ hierarchical_beta = Prior(
510+ "Beta",
511+ alpha=Prior("HalfNormal"),
512+ beta=Prior("HalfNormal"),
513+ dims="channel",
514+ )
515+
516+ Create a transformed hierarchical normal prior by using the `transform`
517+ parameter. Here the "sigmoid" transformation comes from `pm.math`.
518+
519+ .. code-block:: python
520+
521+ transformed_hierarchical_normal = Prior(
522+ "Normal",
523+ mu=Prior("Normal"),
524+ sigma=Prior("HalfNormal"),
525+ transform="sigmoid",
526+ dims="channel",
527+ )
528+
529+ Create a prior with a custom transform function by registering it with
530+ :func:`register_tensor_transform`.
531+
532+ .. code-block:: python
533+
534+ from pymc_extras.prior import register_tensor_transform
535+
536+ def custom_transform(x):
537+ return x ** 2
538+
539+ register_tensor_transform("square", custom_transform)
540+
541+ custom_distribution = Prior("Normal", transform="square")
542+
390543 """
391544
392545 # Taken from https://en.wikipedia.org/wiki/Location%E2%80%93scale_family
@@ -395,9 +548,13 @@ class Prior:
395548 "StudentT" : {"mu" : 0 , "sigma" : 1 },
396549 "ZeroSumNormal" : {"sigma" : 1 },
397550 }
551+ """Available non-centered distributions and their default parameters."""
398552
399553 pymc_distribution : type [pm .Distribution ]
554+ """The PyMC distribution class."""
555+
400556 pytensor_transform : Callable [[pt .TensorLike ], pt .TensorLike ] | None
557+ """The PyTensor transform function."""
401558
402559 @validate_call
403560 def __init__ (
@@ -1323,9 +1480,33 @@ def create_likelihood_variable(
13231480
13241481
13251482class Scaled :
1326- """Scaled distribution for numerical stability."""
1483+ """Scaled distribution for numerical stability.
1484+
1485+ This is the same as multiplying the variable by a constant factor.
1486+
1487+ Parameters
1488+ ----------
1489+ dist : Prior
1490+ The prior distribution to scale.
1491+ factor : pt.TensorLike
1492+ The scaling factor. This will have to be broadcastable to the
1493+ dimensions of the distribution.
1494+
1495+ Examples
1496+ --------
1497+ Create a scaled normal distribution.
1498+
1499+ .. code-block:: python
1500+
1501+ from pymc_extras.prior import Prior, Scaled
1502+
1503+ normal = Prior("Normal", mu=0, sigma=1)
1504+ # Same as Normal(mu=0, sigma=10)
1505+ scaled_normal = Scaled(normal, factor=10)
1506+
1507+ """
13271508
1328- def __init__ (self , dist : Prior , factor : float | pt .TensorVariable ) -> None :
1509+ def __init__ (self , dist : Prior , factor : pt .TensorLike ) -> None :
13291510 self .dist = dist
13301511 self .factor = factor
13311512
0 commit comments