@@ -211,9 +211,7 @@ class TestShapeDimsSize:
211211 [
212212 "implicit" ,
213213 "shape" ,
214- "shape..." ,
215214 "dims" ,
216- "dims..." ,
217215 "size" ,
218216 ],
219217 )
@@ -249,65 +247,36 @@ def test_param_and_batch_shape_combos(
249247 if parametrization == "shape" :
250248 rv = pm .Normal ("rv" , mu = mu , shape = batch_shape + param_shape )
251249 assert rv .eval ().shape == expected_shape
252- elif parametrization == "shape..." :
253- rv = pm .Normal ("rv" , mu = mu , shape = (* batch_shape , ...))
254- assert rv .eval ().shape == batch_shape + param_shape
255250 elif parametrization == "dims" :
256251 rv = pm .Normal ("rv" , mu = mu , dims = batch_dims + param_dims )
257252 assert rv .eval ().shape == expected_shape
258- elif parametrization == "dims..." :
259- rv = pm .Normal ("rv" , mu = mu , dims = (* batch_dims , ...))
260- n_size = len (batch_shape )
261- n_implied = len (param_shape )
262- ndim = n_size + n_implied
263- assert len (pmodel .RV_dims ["rv" ]) == ndim , pmodel .RV_dims
264- assert len (pmodel .RV_dims ["rv" ][:n_size ]) == len (batch_dims )
265- assert len (pmodel .RV_dims ["rv" ][n_size :]) == len (param_dims )
266- if n_implied > 0 :
267- assert pmodel .RV_dims ["rv" ][- 1 ] is None
268253 elif parametrization == "size" :
269254 rv = pm .Normal ("rv" , mu = mu , size = batch_shape + param_shape )
270255 assert rv .eval ().shape == expected_shape
271256 else :
272257 raise NotImplementedError ("Invalid test case parametrization." )
273258
274- @pytest .mark .parametrize ("ellipsis_in" , ["none" , "shape" , "dims" , "both" ])
275- def test_simultaneous_shape_and_dims (self , ellipsis_in ):
259+ def test_simultaneous_shape_and_dims (self ):
276260 with pm .Model () as pmodel :
277261 x = pm .ConstantData ("x" , [1 , 2 , 3 ], dims = "ddata" )
278262
279- if ellipsis_in == "none" :
280- # The shape and dims tuples correspond to each other.
281- # Note: No checks are performed that implied shape (x), shape and dims actually match.
282- y = pm .Normal ("y" , mu = x , shape = (2 , 3 ), dims = ("dshape" , "ddata" ))
283- assert pmodel .RV_dims ["y" ] == ("dshape" , "ddata" )
284- elif ellipsis_in == "shape" :
285- y = pm .Normal ("y" , mu = x , shape = (2 , ...), dims = ("dshape" , "ddata" ))
286- assert pmodel .RV_dims ["y" ] == ("dshape" , "ddata" )
287- elif ellipsis_in == "dims" :
288- y = pm .Normal ("y" , mu = x , shape = (2 , 3 ), dims = ("dshape" , ...))
289- assert pmodel .RV_dims ["y" ] == ("dshape" , None )
290- elif ellipsis_in == "both" :
291- y = pm .Normal ("y" , mu = x , shape = (2 , ...), dims = ("dshape" , ...))
292- assert pmodel .RV_dims ["y" ] == ("dshape" , None )
263+ # The shape and dims tuples correspond to each other.
264+ # Note: No checks are performed that implied shape (x), shape and dims actually match.
265+ y = pm .Normal ("y" , mu = x , shape = (2 , 3 ), dims = ("dshape" , "ddata" ))
266+ assert pmodel .RV_dims ["y" ] == ("dshape" , "ddata" )
293267
294268 assert "dshape" in pmodel .dim_lengths
295269 assert y .eval ().shape == (2 , 3 )
296270
297- @pytest .mark .parametrize ("with_dims_ellipsis" , [False , True ])
298- def test_simultaneous_size_and_dims (self , with_dims_ellipsis ):
271+ def test_simultaneous_size_and_dims (self ):
299272 with pm .Model () as pmodel :
300273 x = pm .ConstantData ("x" , [1 , 2 , 3 ], dims = "ddata" )
301274 assert "ddata" in pmodel .dim_lengths
302275
303276 # Size does not include support dims, so this test must use a dist with support dims.
304277 kwargs = dict (name = "y" , size = (2 , 3 ), mu = at .ones ((3 , 4 )), cov = at .eye (4 ))
305- if with_dims_ellipsis :
306- y = pm .MvNormal (** kwargs , dims = ("dsize" , ...))
307- assert pmodel .RV_dims ["y" ] == ("dsize" , None , None )
308- else :
309- y = pm .MvNormal (** kwargs , dims = ("dsize" , "ddata" , "dsupport" ))
310- assert pmodel .RV_dims ["y" ] == ("dsize" , "ddata" , "dsupport" )
278+ y = pm .MvNormal (** kwargs , dims = ("dsize" , "ddata" , "dsupport" ))
279+ assert pmodel .RV_dims ["y" ] == ("dsize" , "ddata" , "dsupport" )
311280
312281 assert "dsize" in pmodel .dim_lengths
313282 assert y .eval ().shape == (2 , 3 , 4 )
@@ -382,7 +351,6 @@ def test_dist_api_works(self):
382351 pm .Normal .dist (mu = mu , dims = ("town" ,))
383352 assert pm .Normal .dist (mu = mu , shape = (3 ,)).eval ().shape == (3 ,)
384353 assert pm .Normal .dist (mu = mu , shape = (5 , 3 )).eval ().shape == (5 , 3 )
385- assert pm .Normal .dist (mu = mu , shape = (7 , ...)).eval ().shape == (7 , 3 )
386354 assert pm .Normal .dist (mu = mu , size = (3 ,)).eval ().shape == (3 ,)
387355 assert pm .Normal .dist (mu = mu , size = (4 , 3 )).eval ().shape == (4 , 3 )
388356
@@ -408,10 +376,6 @@ def test_mvnormal_shape_size_difference(self):
408376 assert rv .ndim == 3
409377 assert tuple (rv .shape .eval ()) == (5 , 4 , 3 )
410378
411- rv = pm .MvNormal .dist (mu = np .ones ((4 , 3 , 2 )), cov = np .eye (2 ), shape = (6 , 5 , ...))
412- assert rv .ndim == 5
413- assert tuple (rv .shape .eval ()) == (6 , 5 , 4 , 3 , 2 )
414-
415379 rv = pm .MvNormal .dist (mu = [1 , 2 , 3 ], cov = np .eye (3 ), size = (5 , 4 ))
416380 assert tuple (rv .shape .eval ()) == (5 , 4 , 3 )
417381
@@ -422,22 +386,16 @@ def test_convert_dims(self):
422386 assert convert_dims (dims = "town" ) == ("town" ,)
423387 with pytest .raises (ValueError , match = "must be a tuple, str or list" ):
424388 convert_dims (3 )
425- with pytest .raises (ValueError , match = "may only appear in the last position" ):
426- convert_dims (dims = (..., "town" ))
427389
428390 def test_convert_shape (self ):
429391 assert convert_shape (5 ) == (5 ,)
430392 with pytest .raises (ValueError , match = "tuple, TensorVariable, int or list" ):
431393 convert_shape (shape = "notashape" )
432- with pytest .raises (ValueError , match = "may only appear in the last position" ):
433- convert_shape (shape = (3 , ..., 2 ))
434394
435395 def test_convert_size (self ):
436396 assert convert_size (7 ) == (7 ,)
437397 with pytest .raises (ValueError , match = "tuple, TensorVariable, int or list" ):
438398 convert_size (size = "notasize" )
439- with pytest .raises (ValueError , match = "cannot contain" ):
440- convert_size (size = (3 , ...))
441399
442400 def test_lazy_flavors (self ):
443401 assert pm .Uniform .dist (2 , [4 , 5 ], size = [3 , 2 ]).eval ().shape == (3 , 2 )
0 commit comments