1313# limitations under the License.
1414import warnings
1515
16- from typing import Any , Optional , Union
16+ from typing import Optional
1717
1818import aesara
1919import aesara .tensor as at
2424from aesara import scan
2525from aesara .graph import FunctionGraph , rewrite_graph
2626from aesara .graph .basic import Node , clone_replace
27- from aesara .raise_op import Assert
2827from aesara .tensor import TensorVariable
2928from aesara .tensor .random .op import RandomVariable
3029from aesara .tensor .rewriting .basic import ShapeFeature , topo_constant_folding
3130
32- from pymc .aesaraf import convert_observed_data , floatX , intX
31+ from pymc .aesaraf import floatX , intX
3332from pymc .distributions import distribution , multivariate
3433from pymc .distributions .continuous import Flat , Normal , get_tau_sigma
3534from pymc .distributions .distribution import (
4039)
4140from pymc .distributions .logprob import ignore_logprob , logp
4241from pymc .distributions .shape_utils import (
43- Dims ,
44- Shape ,
4542 _change_dist_size ,
4643 change_dist_size ,
47- convert_dims ,
44+ get_support_shape_1d ,
4845 to_tuple ,
4946)
50- from pymc .model import modelcontext
5147from pymc .util import check_dist_not_registered
5248
5349__all__ = [
6157]
6258
6359
64- def get_steps (
65- steps : Optional [Union [int , np .ndarray , TensorVariable ]],
66- * ,
67- shape : Optional [Shape ] = None ,
68- dims : Optional [Dims ] = None ,
69- observed : Optional [Any ] = None ,
70- step_shape_offset : int = 0 ,
71- ):
72- """Extract number of steps from shape / dims / observed information
73-
74- Parameters
75- ----------
76- steps:
77- User specified steps for timeseries distribution
78- shape:
79- User specified shape for timeseries distribution
80- dims:
81- User specified dims for timeseries distribution
82- observed:
83- User specified observed data from timeseries distribution
84- step_shape_offset:
85- Difference between last shape dimension and number of steps in timeseries
86- distribution, defaults to 0
87-
88- Returns
89- -------
90- steps
91- Steps, if specified directly by user, or inferred from the last dimension of
92- shape / dims / observed. When two sources of step information are provided,
93- a symbolic Assert is added to ensure they are consistent.
94- """
95- inferred_steps = None
96- if shape is not None :
97- shape = to_tuple (shape )
98- inferred_steps = shape [- 1 ] - step_shape_offset
99-
100- if inferred_steps is None and dims is not None :
101- dims = convert_dims (dims )
102- model = modelcontext (None )
103- inferred_steps = model .dim_lengths [dims [- 1 ]] - step_shape_offset
104-
105- if inferred_steps is None and observed is not None :
106- observed = convert_observed_data (observed )
107- inferred_steps = observed .shape [- 1 ] - step_shape_offset
108-
109- if inferred_steps is None :
110- inferred_steps = steps
111- # If there are two sources of information for the steps, assert they are consistent
112- elif steps is not None :
113- inferred_steps = Assert (msg = "Steps do not match last shape dimension" )(
114- inferred_steps , at .eq (inferred_steps , steps )
115- )
116- return inferred_steps
117-
118-
11960class RandomWalkRV (SymbolicRandomVariable ):
12061 """RandomWalk Variable"""
12162
@@ -132,21 +73,21 @@ class RandomWalk(Distribution):
13273 rv_type = RandomWalkRV
13374
13475 def __new__ (cls , * args , steps = None , ** kwargs ):
135- steps = get_steps (
136- steps = steps ,
76+ steps = get_support_shape_1d (
77+ support_shape = steps ,
13778 shape = None , # Shape will be checked in `cls.dist`
13879 dims = kwargs .get ("dims" , None ),
13980 observed = kwargs .get ("observed" , None ),
140- step_shape_offset = 1 ,
81+ support_shape_offset = 1 ,
14182 )
14283 return super ().__new__ (cls , * args , steps = steps , ** kwargs )
14384
14485 @classmethod
14586 def dist (cls , init_dist , innovation_dist , steps = None , ** kwargs ) -> at .TensorVariable :
146- steps = get_steps (
147- steps = steps ,
87+ steps = get_support_shape_1d (
88+ support_shape = steps ,
14889 shape = kwargs .get ("shape" ),
149- step_shape_offset = 1 ,
90+ support_shape_offset = 1 ,
15091 )
15192 if steps is None :
15293 raise ValueError ("Must specify steps or shape parameter" )
@@ -391,12 +332,12 @@ class AR(Distribution):
391332 def __new__ (cls , name , rho , * args , steps = None , constant = False , ar_order = None , ** kwargs ):
392333 rhos = at .atleast_1d (at .as_tensor_variable (floatX (rho )))
393334 ar_order = cls ._get_ar_order (rhos = rhos , constant = constant , ar_order = ar_order )
394- steps = get_steps (
395- steps = steps ,
335+ steps = get_support_shape_1d (
336+ support_shape = steps ,
396337 shape = None , # Shape will be checked in `cls.dist`
397338 dims = kwargs .get ("dims" , None ),
398339 observed = kwargs .get ("observed" , None ),
399- step_shape_offset = ar_order ,
340+ support_shape_offset = ar_order ,
400341 )
401342 return super ().__new__ (
402343 cls , name , rhos , * args , steps = steps , constant = constant , ar_order = ar_order , ** kwargs
@@ -427,7 +368,9 @@ def dist(
427368 init_dist = kwargs .pop ("init" )
428369
429370 ar_order = cls ._get_ar_order (rhos = rhos , constant = constant , ar_order = ar_order )
430- steps = get_steps (steps = steps , shape = kwargs .get ("shape" , None ), step_shape_offset = ar_order )
371+ steps = get_support_shape_1d (
372+ support_shape = steps , shape = kwargs .get ("shape" , None ), support_shape_offset = ar_order
373+ )
431374 if steps is None :
432375 raise ValueError ("Must specify steps or shape parameter" )
433376 steps = at .as_tensor_variable (intX (steps ), ndim = 0 )
0 commit comments