# limitations under the License.
from collections.abc import Callable, Sequence
from itertools import chain
+ from typing import cast

+ import numpy as np
+
+ from pytensor.graph import node_rewriter
from pytensor.graph.basic import Variable
from pytensor.tensor.elemwise import DimShuffle
+ from pytensor.tensor.random.op import RandomVariable
from pytensor.xtensor import as_xtensor
+ from pytensor.xtensor.basic import XTensorFromTensor, xtensor_from_tensor
from pytensor.xtensor.type import XTensorVariable

- from pymc import modelcontext
- from pymc.dims.model import with_dims
- from pymc.distributions import transforms
+ from pymc import SymbolicRandomVariable, modelcontext
+ from pymc.dims.distributions.transforms import DimTransform, log_odds_transform, log_transform
from pymc.distributions.distribution import _support_point, support_point
from pymc.distributions.shape_utils import DimsWithEllipsis, convert_dims_with_ellipsis
- from pymc.logprob.transforms import Transform
+ from pymc.logprob.abstract import MeasurableOp, _logprob
+ from pymc.logprob.rewriting import measurable_ir_rewrites_db
+ from pymc.logprob.tensor import MeasurableDimShuffle
+ from pymc.logprob.utils import filter_measurable_variables
from pymc.util import UNSET


@@ -36,25 +44,98 @@ def dimshuffle_support_point(ds_op, _, rv):
    return ds_op(support_point(rv))


+ @_support_point.register(XTensorFromTensor)
+ def xtensor_from_tensor_support_point(xtensor_op, _, rv):
+     # We remove the xtensor_from_tensor operation, so initial_point doesn't have to do a further lowering
+     return xtensor_op(support_point(rv))
+
+
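+ # A measurable counterpart of XTensorFromTensor: `core_dims` records which of the output dims
+ # are the wrapped RV's core (support) dims, or None when they are known to be the rightmost ones.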
+ class MeasurableXTensorFromTensor(MeasurableOp, XTensorFromTensor):
+     __props__ = ("dims", "core_dims")  # type: ignore[assignment]
+
+     def __init__(self, dims, core_dims):
+         super().__init__(dims=dims)
+         self.core_dims = tuple(core_dims) if core_dims is not None else None
+
+
+ @node_rewriter([XTensorFromTensor])
+ def find_measurable_xtensor_from_tensor(fgraph, node) -> list[XTensorVariable] | None:
+     if isinstance(node.op, MeasurableXTensorFromTensor):
+         return None
+
+     xs = filter_measurable_variables(node.inputs)
+
+     if not xs:
+         # Check if we have a transposition instead.
+         # The rewrite that introduces measurable transposes refuses to apply to multivariate RVs,
+         # so we have a chance of inferring the core dims!
+         [ds] = node.inputs
+         ds_node = ds.owner
+         if not (
+             ds_node is not None
+             and isinstance(ds_node.op, DimShuffle)
+             and ds_node.op.is_transpose
+             and filter_measurable_variables(ds_node.inputs)
+         ):
+             return None
+         [x] = ds_node.inputs
+         if not (
+             x.owner is not None and isinstance(x.owner.op, RandomVariable | SymbolicRandomVariable)
+         ):
+             return None
+
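+         # Recreate the transpose as a measurable DimShuffle so logprob inference can see through it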
+         measurable_x = MeasurableDimShuffle(**ds_node.op._props_dict())(x)  # type: ignore[attr-defined]
+
+         ndim_supp = x.owner.op.ndim_supp
+         if ndim_supp:
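+             # Undo the transpose to recover the dims in the underlying RV's own order;
+             # its core (support) dims are the trailing `ndim_supp` entries of that order.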
+             inverse_transpose = np.argsort(ds_node.op.shuffle)
+             dims = node.op.dims
+             dims_before_transpose = tuple(dims[i] for i in inverse_transpose)
+             core_dims = dims_before_transpose[-ndim_supp:]
+         else:
+             core_dims = ()
+
+         new_out = MeasurableXTensorFromTensor(dims=node.op.dims, core_dims=core_dims)(measurable_x)
+     else:
+         # If this happens we know there's no measurable transpose in between and we can
+         # safely infer the core_dims positionally when the inner logp is returned
+         new_out = MeasurableXTensorFromTensor(dims=node.op.dims, core_dims=None)(*node.inputs)
+     return [cast(XTensorVariable, new_out)]
+
+
+ @_logprob.register(MeasurableXTensorFromTensor)
+ def measurable_xtensor_from_tensor(op, values, rv, **kwargs):
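+     # `.values` unwraps each XTensorVariable so the inner op's logprob sees plain tensors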
+     rv_logp = _logprob(rv.owner.op, tuple(v.values for v in values), *rv.owner.inputs, **kwargs)
+     if op.core_dims is None:
+         # The core_dims of the inner rv are on the right
+         dims = op.dims[: rv_logp.ndim]
+     else:
+         # We inferred where the core_dims are!
+         dims = [d for d in op.dims if d not in op.core_dims]
+     return xtensor_from_tensor(rv_logp, dims=dims)
+
+
+ measurable_ir_rewrites_db.register(
+     "measurable_xtensor_from_tensor", find_measurable_xtensor_from_tensor, "basic", "xtensor"
+ )
+
+
class DimDistribution:
    """Base class for PyMC distributions that wrap pytensor.xtensor.random operations and follow xarray-like semantics."""

    xrv_op: Callable
-     default_transform: Transform | None = None
+     default_transform: DimTransform | None = None

    @staticmethod
    def _as_xtensor(x):
        try:
            return as_xtensor(x)
        except TypeError:
-             try:
-                 return with_dims(x)
-             except ValueError:
-                 raise ValueError(
-                     f"Variable {x} must have dims associated with it.\n"
-                     "To avoid subtle bugs, PyMC does not make any assumptions about the dims of parameters.\n"
-                     "Use `as_xtensor` with the `dims` keyword argument to specify the dims explicitly."
-                 )
+             raise ValueError(
+                 f"Variable {x} must have dims associated with it.\n"
+                 "To avoid subtle bugs, PyMC does not make any assumptions about the dims of parameters.\n"
+                 "Use `as_xtensor` with the `dims` keyword argument to specify the dims explicitly."
+             )

    def __new__(
        cls,
@@ -119,10 +200,22 @@ def __new__(
        else:
            # Align observed dims with those of the RV
            # TODO: If this fails give a more informative error message
-             observed = observed.transpose(*rv_dims).values
+             observed = observed.transpose(*rv_dims)
+
+         # Check that the user didn't pass regular (non-DimTransform) transforms
+         if transform not in (UNSET, None):
+             if not isinstance(transform, DimTransform):
+                 raise TypeError(
+                     f"Transform must be a DimTransform, from pymc.dims.transforms, but got {type(transform)}."
+                 )
+         if default_transform not in (UNSET, None):
+             if not isinstance(default_transform, DimTransform):
+                 raise TypeError(
+                     f"default_transform must be a DimTransform, from pymc.dims.transforms, but got {type(default_transform)}."
+                 )

        rv = model.register_rv(
-             rv.values,
+             rv,
            name=name,
            observed=observed,
            total_size=total_size,
@@ -182,10 +275,10 @@ def dist(self, *args, core_dims: str | Sequence[str] | None = None, **kwargs):
class PositiveDimDistribution(DimDistribution):
    """Base class for positive continuous distributions."""

-     default_transform = transforms.log
+     default_transform = log_transform


class UnitDimDistribution(DimDistribution):
    """Base class for unit-valued distributions."""

-     default_transform = transforms.logodds
+     default_transform = log_odds_transform