@@ -1399,19 +1399,44 @@ def test_issue_3706(self):
13991399
14001400class TestZeroSumNormal :
14011401 @pytest .mark .parametrize (
1402- "dims, zerosum_axes, shape " ,
1402+ "dims, zerosum_axes" ,
14031403 [
1404- (("regions" , "answers" ), None , None ),
1405- (("regions" , "answers" ), 1 , None ),
1406- (("regions" , "answers" ), 2 , None ),
1407- (None , None , (len (COORDS ["regions" ]), len (COORDS ["answers" ]))),
1408- (None , 1 , (len (COORDS ["regions" ]), len (COORDS ["answers" ]))),
1409- (None , 2 , (len (COORDS ["regions" ]), len (COORDS ["answers" ]))),
1404+ (("regions" , "answers" ), None ),
1405+ (("regions" , "answers" ), 1 ),
1406+ (("regions" , "answers" ), 2 ),
14101407 ],
14111408 )
1412- def test_zsn_dims_shape (self , dims , zerosum_axes , shape ):
1409+ def test_zsn_dims (self , dims , zerosum_axes ):
14131410 with pm .Model (coords = COORDS ) as m :
1414- v = pm .ZeroSumNormal ("v" , dims = dims , shape = shape , zerosum_axes = zerosum_axes )
1411+ v = pm .ZeroSumNormal ("v" , dims = dims , zerosum_axes = zerosum_axes )
1412+ s = pm .sample (10 , chains = 1 , tune = 100 )
1413+
1414+ # to test forward graph
1415+ random_samples = pm .draw (v , draws = 10 )
1416+
1417+ assert s .posterior .v .shape == (1 , 10 , len (COORDS ["regions" ]), len (COORDS ["answers" ]))
1418+
1419+ ndim_supp = v .owner .op .ndim_supp
1420+ zerosum_axes = np .arange (- ndim_supp , 0 )
1421+ nonzero_axes = np .arange (v .ndim - ndim_supp )
1422+ for samples in [
1423+ s .posterior .v ,
1424+ random_samples ,
1425+ ]:
1426+ self .assert_zerosum_axes (samples , zerosum_axes )
1427+ self .assert_zerosum_axes (samples , nonzero_axes , check_zerosum_axes = False )
1428+
1429+ @pytest .mark .parametrize (
1430+ "zerosum_axes, shape" ,
1431+ [
1432+ (None , (len (COORDS ["regions" ]), len (COORDS ["answers" ]))),
1433+ (1 , (len (COORDS ["regions" ]), len (COORDS ["answers" ]))),
1434+ (2 , (len (COORDS ["regions" ]), len (COORDS ["answers" ]))),
1435+ ],
1436+ )
1437+ def test_zsn_shape (self , shape , zerosum_axes ):
1438+ with pm .Model (coords = COORDS ) as m :
1439+ v = pm .ZeroSumNormal ("v" , shape = shape , zerosum_axes = zerosum_axes )
14151440 s = pm .sample (10 , chains = 1 , tune = 100 )
14161441
14171442 # to test forward graph
0 commit comments