|
21 | 21 |
|
22 | 22 | from aeppl.abstract import _get_measurable_outputs |
23 | 23 | from aeppl.logprob import _logprob |
24 | | -from aesara.graph import FunctionGraph, rewrite_graph |
25 | 24 | from aesara.graph.basic import Node, clone_replace |
26 | 25 | from aesara.raise_op import Assert |
27 | 26 | from aesara.tensor import TensorVariable |
28 | 27 | from aesara.tensor.random.op import RandomVariable |
29 | | -from aesara.tensor.rewriting.basic import ShapeFeature, topo_constant_folding |
30 | 28 |
|
31 | | -from pymc.aesaraf import convert_observed_data, floatX, intX |
| 29 | +from pymc.aesaraf import constant_fold, convert_observed_data, floatX, intX |
32 | 30 | from pymc.distributions import distribution, multivariate |
33 | 31 | from pymc.distributions.continuous import Flat, Normal, get_tau_sigma |
34 | 32 | from pymc.distributions.distribution import ( |
|
46 | 44 | convert_dims, |
47 | 45 | to_tuple, |
48 | 46 | ) |
| 47 | +from pymc.exceptions import NotConstantValueError |
49 | 48 | from pymc.model import modelcontext |
50 | 49 | from pymc.util import check_dist_not_registered |
51 | 50 |
|
@@ -472,14 +471,9 @@ def _get_ar_order(cls, rhos: TensorVariable, ar_order: Optional[int], constant: |
472 | 471 | If inferred ar_order cannot be inferred from rhos or if it is less than 1 |
473 | 472 | """ |
474 | 473 | if ar_order is None: |
475 | | - shape_fg = FunctionGraph( |
476 | | - outputs=[rhos.shape[-1]], |
477 | | - features=[ShapeFeature()], |
478 | | - clone=True, |
479 | | - ) |
480 | | - (folded_shape,) = rewrite_graph(shape_fg, custom_rewrite=topo_constant_folding).outputs |
481 | | - folded_shape = getattr(folded_shape, "data", None) |
482 | | - if folded_shape is None: |
| 474 | + try: |
| 475 | + (folded_shape,) = constant_fold((rhos.shape[-1],)) |
| 476 | + except NotConstantValueError: |
483 | 477 | raise ValueError( |
484 | 478 | "Could not infer ar_order from last dimension of rho. Pass it " |
485 | 479 | "explictily or make sure rho have a static shape" |
|
0 commit comments