@@ -149,7 +149,7 @@ def fn(*args, **kwargs):
149149 return fn
150150
151151
152- def _make_rv_and_resize_shape (
152+ def _make_rv_and_resize_shape_from_dims (
153153 * ,
154154 cls ,
155155 dims : Optional [StrongDims ],
@@ -159,21 +159,23 @@ def _make_rv_and_resize_shape(
159159 ** kwargs ,
160160) -> Tuple [Variable , StrongShape ]:
161161 """Creates the RV, possibly using dims or observed to determine a resize shape (if needed)."""
162- resize_shape = None
162+ resize_shape_from_dims = None
163163 size_or_shape = kwargs .get ("size" ) or kwargs .get ("shape" )
164164
165- # Create the RV without dims or observed information
165+ # Preference is given to size or shape. If not specified, we rely on dims and
166+ # finally, observed, to determine the shape of the variable. Because dims can be
167+ # specified on the fly, we need a two-step process where we first create the RV
168+ # without dims information and then resize it.
169+ if not size_or_shape and observed is not None :
170+ kwargs ["shape" ] = tuple (observed .shape )
171+
172+ # Create the RV without dims information
166173 rv_out = cls .dist (* args , ** kwargs )
167174
168- # Preference is given to size or shape, if not provided we use dims and observed
169- # to resize the variable
170- if not size_or_shape :
171- if dims is not None :
172- resize_shape = shape_from_dims (dims , tuple (rv_out .shape ), model )
173- elif observed is not None :
174- resize_shape = tuple (observed .shape )
175+ if not size_or_shape and dims is not None :
176+ resize_shape_from_dims = shape_from_dims (dims , tuple (rv_out .shape ), model )
175177
176- return rv_out , resize_shape
178+ return rv_out , resize_shape_from_dims
177179
178180
179181class Distribution (metaclass = DistributionMeta ):
@@ -257,16 +259,17 @@ def __new__(
257259 if observed is not None :
258260 observed = convert_observed_data (observed )
259261
260- # Create the RV, possibly taking into consideration dims and observed to
261- # determine its shape
262- rv_out , resize_shape = _make_rv_and_resize_shape (
262+ # Create the RV, without taking `dims` into consideration
263+ rv_out , resize_shape_from_dims = _make_rv_and_resize_shape_from_dims (
263264 cls = cls , dims = dims , model = model , observed = observed , args = args , ** kwargs
264265 )
265266
266- # A shape was specified only through `dims`, or implied by `observed`.
267- if resize_shape :
268- resize_size = find_size (shape = resize_shape , size = None , ndim_supp = cls .rv_op .ndim_supp )
269- rv_out = change_rv_size (rv = rv_out , new_size = resize_size , expand = False )
267+ # Resize variable based on `dims` information
268+ if resize_shape_from_dims :
269+ resize_size_from_dims = find_size (
270+ shape = resize_shape_from_dims , size = None , ndim_supp = cls .rv_op .ndim_supp
271+ )
272+ rv_out = change_rv_size (rv = rv_out , new_size = resize_size_from_dims , expand = False )
270273
271274 rv_out = model .register_rv (
272275 rv_out ,
@@ -452,16 +455,17 @@ def __new__(
452455 if observed is not None :
453456 observed = convert_observed_data (observed )
454457
455- # Create the RV, possibly taking into consideration dims and observed to
456- # determine its shape
457- rv_out , resize_shape = _make_rv_and_resize_shape (
458+ # Create the RV, without taking `dims` into consideration
459+ rv_out , resize_shape_from_dims = _make_rv_and_resize_shape_from_dims (
458460 cls = cls , dims = dims , model = model , observed = observed , args = args , ** kwargs
459461 )
460462
461- # A shape was specified only through `dims`, or implied by `observed`.
462- if resize_shape :
463- resize_size = find_size (shape = resize_shape , size = None , ndim_supp = rv_out .tag .ndim_supp )
464- rv_out = cls .change_size (rv = rv_out , new_size = resize_size , expand = False )
463+ # Resize variable based on `dims` information
464+ if resize_shape_from_dims :
465+ resize_size_from_dims = find_size (
466+ shape = resize_shape_from_dims , size = None , ndim_supp = rv_out .tag .ndim_supp
467+ )
468+ rv_out = cls .change_size (rv = rv_out , new_size = resize_size_from_dims , expand = False )
465469
466470 rv_out = model .register_rv (
467471 rv_out ,
0 commit comments