|
21 | 21 |
|
22 | 22 | from aeppl.abstract import _get_measurable_outputs |
23 | 23 | from aeppl.logprob import _logprob |
| 24 | +from aesara.graph import FunctionGraph, rewrite_graph |
24 | 25 | from aesara.graph.basic import Node, clone_replace |
25 | 26 | from aesara.raise_op import Assert |
26 | 27 | from aesara.tensor import TensorVariable |
27 | 28 | from aesara.tensor.random.op import RandomVariable |
| 29 | +from aesara.tensor.rewriting.basic import ShapeFeature, topo_constant_folding |
28 | 30 |
|
29 | | -from pymc.aesaraf import constant_fold, convert_observed_data, floatX, intX |
| 31 | +from pymc.aesaraf import convert_observed_data, floatX, intX |
30 | 32 | from pymc.distributions import distribution, multivariate |
31 | 33 | from pymc.distributions.continuous import Flat, Normal, get_tau_sigma |
32 | 34 | from pymc.distributions.distribution import ( |
|
44 | 46 | convert_dims, |
45 | 47 | to_tuple, |
46 | 48 | ) |
47 | | -from pymc.exceptions import NotConstantValueError |
48 | 49 | from pymc.model import modelcontext |
49 | 50 | from pymc.util import check_dist_not_registered |
50 | 51 |
|
@@ -471,9 +472,14 @@ def _get_ar_order(cls, rhos: TensorVariable, ar_order: Optional[int], constant: |
471 | 472 | If inferred ar_order cannot be inferred from rhos or if it is less than 1 |
472 | 473 | """ |
473 | 474 | if ar_order is None: |
474 | | - try: |
475 | | - (folded_shape,) = constant_fold((rhos.shape[-1],)) |
476 | | - except NotConstantValueError: |
| 475 | + shape_fg = FunctionGraph( |
| 476 | + outputs=[rhos.shape[-1]], |
| 477 | + features=[ShapeFeature()], |
| 478 | + clone=True, |
| 479 | + ) |
| 480 | + (folded_shape,) = rewrite_graph(shape_fg, custom_rewrite=topo_constant_folding).outputs |
| 481 | + folded_shape = getattr(folded_shape, "data", None) |
| 482 | + if folded_shape is None: |
477 | 483 | raise ValueError( |
478 | 484 | "Could not infer ar_order from last dimension of rho. Pass it " |
479 | 485 | "explictily or make sure rho have a static shape" |
|
0 commit comments