|
41 | 41 | from pytensor.graph.rewriting.basic import node_rewriter |
42 | 42 | from pytensor.tensor import TensorVariable |
43 | 43 | from pytensor.tensor.basic import Join, MakeVector |
44 | | -from pytensor.tensor.elemwise import DimShuffle |
| 44 | +from pytensor.tensor.elemwise import DimShuffle, Elemwise |
45 | 45 | from pytensor.tensor.random.op import RandomVariable |
46 | 46 | from pytensor.tensor.random.rewriting import ( |
47 | 47 | local_dimshuffle_rv_lift, |
48 | 48 | ) |
49 | 49 |
|
50 | | -from pymc.logprob.abstract import MeasurableOp, _logprob, _logprob_helper, promised_valued_rv |
| 50 | +from pymc.logprob.abstract import ( |
| 51 | + MeasurableOp, |
| 52 | + ValuedRV, |
| 53 | + _logprob, |
| 54 | + _logprob_helper, |
| 55 | + promised_valued_rv, |
| 56 | +) |
51 | 57 | from pymc.logprob.rewriting import ( |
52 | 58 | assume_valued_outputs, |
53 | 59 | early_measurable_ir_rewrites_db, |
|
57 | 63 | from pymc.logprob.utils import ( |
58 | 64 | check_potential_measurability, |
59 | 65 | filter_measurable_variables, |
| 66 | + get_related_valued_nodes, |
60 | 67 | replace_rvs_by_values, |
61 | 68 | ) |
62 | 69 | from pymc.pytensorf import constant_fold |
@@ -183,6 +190,9 @@ class MeasurableDimShuffle(MeasurableOp, DimShuffle): |
183 | 190 | # find it locally and fails when a new `Op` is initialized |
184 | 191 | c_func_file = str(DimShuffle.get_path(Path(DimShuffle.c_func_file))) # type: ignore[arg-type] |
185 | 192 |
|
| 193 | + def __str__(self): |
| 194 | + return f"Measurable{super().__str__()}" |
| 195 | + |
186 | 196 |
|
187 | 197 | @_logprob.register(MeasurableDimShuffle) |
188 | 198 | def logprob_dimshuffle(op: MeasurableDimShuffle, values, base_var, **kwargs): |
@@ -215,29 +225,66 @@ def logprob_dimshuffle(op: MeasurableDimShuffle, values, base_var, **kwargs): |
215 | 225 | return raw_logp.dimshuffle(redo_ds) |
216 | 226 |
|
217 | 227 |
|
| 228 | +def _elemwise_univariate_chain(fgraph, node) -> bool: |
| 229 | + # Return True when `node` (single input, single output) lies on an acceptable chain: upstream, its input must lead to a RandomVariable/SymbolicRandomVariable root through MeasurableTransform ops only; downstream, its output must reach a ValuedRV through single-output Elemwise ops only. |
| 230 | + from pymc.distributions.distribution import SymbolicRandomVariable |
| 231 | + from pymc.logprob.transforms import MeasurableTransform |
| 232 | + |
| 233 | + [inp] = node.inputs |
| 234 | + [out] = node.outputs |
| 235 | + |
| 236 | + def elemwise_root(var: TensorVariable) -> TensorVariable | None: |
| 237 | + if isinstance(var.owner.op, RandomVariable | SymbolicRandomVariable): |
| 238 | + return var |
| 239 | + elif isinstance(var.owner.op, MeasurableTransform): |
| 240 | + return elemwise_root(var.owner.inputs[var.owner.op.measurable_input_idx]) |
| 241 | + else: |
| 242 | + return None |
| 243 | + |
| 244 | + # Walk upstream from the input: fail if any op other than a MeasurableTransform sits between it and the base RV, then require that base RV to be univariate (ndim_supp == 0) |
| 245 | + root = elemwise_root(inp) |
| 246 | + if root is None: |
| 247 | + return False |
| 248 | + elif root.owner.op.ndim_supp != 0: |
| 249 | + # A root with ndim_supp != 0 is accepted only when get_related_valued_nodes reports a valued node for this node (i.e. the variable is directly valued) |
| 250 | + return any(get_related_valued_nodes(fgraph, node)) |
| 251 | + |
| 252 | + def elemwise_leaf(var: TensorVariable, clients=fgraph.clients) -> bool: |
| 253 | + var_clients = clients[var] |
| 254 | + if len(var_clients) != 1: |
| 255 | + return False |
| 256 | + [(client, _)] = var_clients |
| 257 | + if isinstance(client.op, ValuedRV): |
| 258 | + return True |
| 259 | + elif isinstance(client.op, Elemwise) and len(client.outputs) == 1: |
| 260 | + return elemwise_leaf(client.outputs[0]) |
| 261 | + else: |
| 262 | + return False |
| 263 | + |
| 264 | + # Walk downstream from the output: every variable on the way must have exactly one client, and only single-output Elemwise ops may sit between this node and the ValuedRV |
| 265 | + return elemwise_leaf(out) |
| 266 | + |
| 267 | + |
| 266 | + |
| 267 | + |
218 | 268 | @node_rewriter([DimShuffle]) |
219 | 269 | def find_measurable_dimshuffles(fgraph, node) -> list[TensorVariable] | None: |
220 | 270 | r"""Find `Dimshuffle`\s for which a `logprob` can be computed.""" |
221 | | - from pymc.distributions.distribution import SymbolicRandomVariable |
222 | | - |
223 | 271 | if isinstance(node.op, MeasurableOp): |
224 | 272 | return None |
225 | 273 |
|
226 | 274 | if not filter_measurable_variables(node.inputs): |
227 | 275 | return None |
228 | 276 |
|
229 | | - base_var = node.inputs[0] |
| 277 | + # In cases where DimShuffle transposes dimensions, we only apply this rewrite when only Elemwise |
| 278 | + # operations separate it from the valued node. Further transformations likely need to know where |
| 279 | + # the support axes are for a correct implementation (and thus assume they are the rightmost axes). |
| 280 | + # TODO: When we include the support axis as meta information in each intermediate MeasurableVariable, |
| 281 | + # we can lift this restriction (see https://github.com/pymc-devs/pymc/issues/6360) |
| 282 | + if tuple(node.op.shuffle) != tuple(sorted(node.op.shuffle)) and not _elemwise_univariate_chain( |
| 283 | + fgraph, node |
| 284 | + ): |
| 285 | + return None |
230 | 286 |
|
231 | | - # We can only apply this rewrite directly to `RandomVariable`s, as those are |
232 | | - # the only `Op`s for which we always know the support axis. Other measurable |
233 | | - # variables can have arbitrary support axes (e.g., if they contain separate |
234 | | - # `MeasurableDimShuffle`s). Most measurable variables with `DimShuffle`s |
235 | | - # should still be supported as long as the `DimShuffle`s can be merged/ |
236 | | - # lifted towards the base RandomVariable. |
237 | | - # TODO: If we include the support axis as meta information in each |
238 | | - # intermediate MeasurableVariable, we can lift this restriction. |
239 | | - if not isinstance(base_var.owner.op, RandomVariable | SymbolicRandomVariable): |
240 | | - return None # pragma: no cover |
| 287 | + base_var = node.inputs[0] |
241 | 288 |
|
242 | 289 | measurable_dimshuffle = MeasurableDimShuffle(node.op.input_broadcastable, node.op.new_order)( |
243 | 290 | base_var |
|
0 commit comments