|
33 | 33 | from aesara.tensor.var import TensorVariable |
34 | 34 | from typing_extensions import TypeAlias |
35 | 35 |
|
36 | | -from pymc.aesaraf import change_rv_size |
| 36 | +from pymc.aesaraf import change_rv_size, convert_observed_data |
37 | 37 | from pymc.distributions.shape_utils import ( |
38 | 38 | Dims, |
39 | 39 | Shape, |
40 | 40 | Size, |
| 41 | + StrongDims, |
41 | 42 | StrongShape, |
42 | 43 | convert_dims, |
43 | 44 | convert_shape, |
44 | 45 | convert_size, |
45 | 46 | find_size, |
46 | | - resize_from_dims, |
47 | | - resize_from_observed, |
| 47 | + shape_from_dims, |
48 | 48 | ) |
49 | 49 | from pymc.printing import str_for_dist, str_for_symbolic_dist |
50 | 50 | from pymc.util import UNSET |
@@ -152,29 +152,28 @@ def fn(*args, **kwargs): |
def _make_rv_and_resize_shape(
    *,
    cls,
    dims: Optional[StrongDims],
    model,
    observed,
    args,
    **kwargs,
) -> Tuple[Variable, Optional[StrongShape]]:
    """Create the RV and determine the shape it should be resized to, if any.

    The RV is always created through ``cls.dist(*args, **kwargs)``, i.e. without
    any ``dims`` or ``observed`` information. A resize shape is only computed when
    neither ``size`` nor ``shape`` was passed in ``kwargs``; in that case ``dims``
    takes precedence over ``observed``.

    Parameters
    ----------
    cls
        Object exposing a ``dist(*args, **kwargs)`` callable that builds the RV.
    dims
        Already converted dims, or ``None``. Forwarded to ``shape_from_dims``
        together with the RV's shape and ``model``.
    model
        Forwarded unchanged to ``shape_from_dims``.
    observed
        Already converted observed data, or ``None``. Only its ``.shape``
        attribute is read.
    args
        Positional arguments forwarded to ``cls.dist``.
    **kwargs
        Keyword arguments forwarded to ``cls.dist``. The ``size`` and ``shape``
        entries are additionally inspected here.

    Returns
    -------
    rv_out
        The value returned by ``cls.dist``.
    resize_shape
        ``None`` when ``size``/``shape`` was given or when both ``dims`` and
        ``observed`` are ``None``; otherwise the shape derived from ``dims``
        or ``tuple(observed.shape)``.
    """
    # Compare against None explicitly instead of relying on truthiness:
    #  * an empty tuple (scalar) or 0 is a value the caller provided on purpose
    #    and must not be silently overridden by dims/observed;
    #  * array-like values (e.g. a multi-element NumPy array) cannot be
    #    truth-tested and would raise.
    has_size_or_shape = kwargs.get("size") is not None or kwargs.get("shape") is not None

    # Create the RV without dims or observed information
    rv_out = cls.dist(*args, **kwargs)

    # Preference is given to size or shape; if neither was provided we use dims,
    # and only then observed, to determine the shape to resize to.
    resize_shape = None
    if not has_size_or_shape:
        if dims is not None:
            resize_shape = shape_from_dims(dims, tuple(rv_out.shape), model)
        elif observed is not None:
            resize_shape = tuple(observed.shape)

    return rv_out, resize_shape
178 | 177 |
|
179 | 178 |
|
180 | 179 | class Distribution(metaclass=DistributionMeta): |
@@ -254,15 +253,20 @@ def __new__( |
254 | 253 | if not isinstance(name, string_types): |
255 | 254 | raise TypeError(f"Name needs to be a string but got: {name}") |
256 | 255 |
|
257 | | - # Create the RV and process dims and observed to determine |
258 | | - # a shape by which the created RV may need to be resized. |
259 | | - rv_out, dims, observed, resize_shape = _make_rv_and_resize_shape( |
| 256 | + dims = convert_dims(dims) |
| 257 | + if observed is not None: |
| 258 | + observed = convert_observed_data(observed) |
| 259 | + |
| 260 | + # Create the RV, possibly taking into consideration dims and observed to |
| 261 | + # determine its shape |
| 262 | + rv_out, resize_shape = _make_rv_and_resize_shape( |
260 | 263 | cls=cls, dims=dims, model=model, observed=observed, args=args, **kwargs |
261 | 264 | ) |
262 | 265 |
|
| 266 | + # A shape was specified only through `dims`, or implied by `observed`. |
263 | 267 | if resize_shape: |
264 | | - # A batch size was specified through `dims`, or implied by `observed`. |
265 | | - rv_out = change_rv_size(rv=rv_out, new_size=resize_shape, expand=True) |
| 268 | + resize_size = find_size(shape=resize_shape, size=None, ndim_supp=cls.rv_op.ndim_supp) |
| 269 | + rv_out = change_rv_size(rv=rv_out, new_size=resize_size, expand=False) |
266 | 270 |
|
267 | 271 | rv_out = model.register_rv( |
268 | 272 | rv_out, |
@@ -336,11 +340,7 @@ def dist( |
336 | 340 | shape = convert_shape(shape) |
337 | 341 | size = convert_size(size) |
338 | 342 |
|
339 | | - create_size, ndim_expected, ndim_batch, ndim_supp = find_size( |
340 | | - shape=shape, size=size, ndim_supp=cls.rv_op.ndim_supp |
341 | | - ) |
342 | | - # Create the RV with a `size` right away. |
343 | | - # This is not necessarily the final result. |
| 343 | + create_size = find_size(shape=shape, size=size, ndim_supp=cls.rv_op.ndim_supp) |
344 | 344 | rv_out = cls.rv_op(*dist_params, size=create_size, **kwargs) |
345 | 345 |
|
346 | 346 | rv_out.logp = _make_nice_attr_error("rv.logp(x)", "pm.logp(rv, x)") |
@@ -448,19 +448,20 @@ def __new__( |
448 | 448 | if not isinstance(name, string_types): |
449 | 449 | raise TypeError(f"Name needs to be a string but got: {name}") |
450 | 450 |
|
451 | | - # Create the RV and process dims and observed to determine |
452 | | - # a shape by which the created RV may need to be resized. |
453 | | - rv_out, dims, observed, resize_shape = _make_rv_and_resize_shape( |
| 451 | + dims = convert_dims(dims) |
| 452 | + if observed is not None: |
| 453 | + observed = convert_observed_data(observed) |
| 454 | + |
| 455 | + # Create the RV, possibly taking into consideration dims and observed to |
| 456 | + # determine its shape |
| 457 | + rv_out, resize_shape = _make_rv_and_resize_shape( |
454 | 458 | cls=cls, dims=dims, model=model, observed=observed, args=args, **kwargs |
455 | 459 | ) |
456 | 460 |
|
| 461 | + # A shape was specified only through `dims`, or implied by `observed`. |
457 | 462 | if resize_shape: |
458 | | - # A batch size was specified through `dims`, or implied by `observed`. |
459 | | - rv_out = cls.change_size( |
460 | | - rv=rv_out, |
461 | | - new_size=resize_shape, |
462 | | - expand=True, |
463 | | - ) |
| 463 | + resize_size = find_size(shape=resize_shape, size=None, ndim_supp=rv_out.tag.ndim_supp) |
| 464 | + rv_out = cls.change_size(rv=rv_out, new_size=resize_size, expand=False) |
464 | 465 |
|
465 | 466 | rv_out = model.register_rv( |
466 | 467 | rv_out, |
@@ -529,18 +530,17 @@ def dist( |
529 | 530 | shape = convert_shape(shape) |
530 | 531 | size = convert_size(size) |
531 | 532 |
|
532 | | - create_size, ndim_expected, ndim_batch, ndim_supp = find_size( |
533 | | - shape=shape, size=size, ndim_supp=cls.ndim_supp(*dist_params) |
534 | | - ) |
535 | | - # Create the RV with a `size` right away. |
536 | | - # This is not necessarily the final result. |
537 | | - graph = cls.rv_op(*dist_params, size=create_size, **kwargs) |
| 533 | + ndim_supp = cls.ndim_supp(*dist_params) |
| 534 | + create_size = find_size(shape=shape, size=size, ndim_supp=ndim_supp) |
| 535 | + rv_out = cls.rv_op(*dist_params, size=create_size, **kwargs) |
| 536 | + # This is needed for resizing from dims in `__new__` |
| 537 | + rv_out.tag.ndim_supp = ndim_supp |
538 | 538 |
|
539 | 539 | # TODO: Create new attr error stating that these are not available for DerivedDistribution |
540 | 540 | # rv_out.logp = _make_nice_attr_error("rv.logp(x)", "pm.logp(rv, x)") |
541 | 541 | # rv_out.logcdf = _make_nice_attr_error("rv.logcdf(x)", "pm.logcdf(rv, x)") |
542 | 542 | # rv_out.random = _make_nice_attr_error("rv.random()", "rv.eval()") |
543 | | - return graph |
| 543 | + return rv_out |
544 | 544 |
|
545 | 545 |
|
546 | 546 | @singledispatch |
|
0 commit comments