@@ -1401,12 +1401,12 @@ class TestZeroSumNormal:
14011401 @pytest .mark .parametrize (
14021402 "dims, zerosum_axes, shape" ,
14031403 [
1404- (("regions" , "answers" ), "answers" , None ),
1405- (("regions" , "answers" ), ( "regions" , "answers" ) , None ),
1406- (("regions" , "answers" ), 0 , None ),
1407- (( "regions" , "answers" ), - 1 , None ),
1408- (( "regions" , "answers" ), ( 0 , 1 ), None ),
1409- (None , - 2 , (len (COORDS ["regions" ]), len (COORDS ["answers" ]))),
1404+ (("regions" , "answers" ), None , None ),
1405+ (("regions" , "answers" ), 1 , None ),
1406+ (("regions" , "answers" ), 2 , None ),
1407+ (None , None , ( len ( COORDS [ "regions" ]), len ( COORDS [ "answers" ])) ),
1408+ (None , 1 , ( len ( COORDS [ "regions" ]), len ( COORDS [ "answers" ])) ),
1409+ (None , 2 , (len (COORDS ["regions" ]), len (COORDS ["answers" ]))),
14101410 ],
14111411 )
14121412 def test_zsn_dims_shape (self , dims , zerosum_axes , shape ):
@@ -1419,41 +1419,27 @@ def test_zsn_dims_shape(self, dims, zerosum_axes, shape):
14191419
14201420 assert s .posterior .v .shape == (1 , 10 , len (COORDS ["regions" ]), len (COORDS ["answers" ]))
14211421
1422- if not isinstance (zerosum_axes , (list , tuple )):
1423- zerosum_axes = [zerosum_axes ]
1422+ zerosum_axes = np .arange (- v .owner .op .ndim_supp , 0 )
1423+ nonzero_axes = np .arange (v .ndim - v .owner .op .ndim_supp )
1424+
1425+ for ax in zerosum_axes :
1426+ for samples in [
1427+ s .posterior .v .mean (axis = ax ),
1428+ random_samples .mean (axis = ax ),
1429+ ]:
1430+ assert np .isclose (
1431+ samples , 0
1432+ ).all (), f"{ ax } is a zerosum_axis but is not summing to 0 across all samples."
14241433
1425- if isinstance ( zerosum_axes [ 0 ], str ) :
1426- for ax in zerosum_axes :
1434+ if nonzero_axes :
1435+ for ax in nonzero_axes :
14271436 for samples in [
1428- s .posterior .v .mean (dim = ax ),
1429- random_samples .mean (axis = dims . index ( ax ) + 1 ),
1437+ s .posterior .v .mean (axis = ax ),
1438+ random_samples .mean (axis = ax ),
14301439 ]:
1431- assert np .isclose (
1440+ assert not np .isclose (
14321441 samples , 0
1433- ).all (), f"{ ax } is a zerosum_axis but is not summing to 0 across all samples."
1434-
1435- nonzero_axes = list (set (dims ).difference (zerosum_axes ))
1436- if nonzero_axes :
1437- for ax in nonzero_axes :
1438- for samples in [
1439- s .posterior .v .mean (dim = ax ),
1440- random_samples .mean (axis = dims .index (ax ) + 1 ),
1441- ]:
1442- assert not np .isclose (
1443- samples , 0
1444- ).all (), f"{ ax } is not a zerosum_axis, but is nonetheless summing to 0 across all samples."
1445-
1446- else :
1447- for ax in zerosum_axes :
1448- if ax < 0 :
1449- assert np .isclose (
1450- s .posterior .v .mean (axis = ax ), 0
1451- ).all (), f"{ ax } is a zerosum_axis but is not summing to 0 across all samples."
1452- else :
1453- ax = ax + 2 # because 'chain' and 'draw' are added as new axes after sampling
1454- assert np .isclose (
1455- s .posterior .v .mean (axis = ax ), 0
1456- ).all (), f"{ ax } is a zerosum_axis but is not summing to 0 across all samples."
1442+ ).all (), f"{ ax } is not a zerosum_axis, but is nonetheless summing to 0 across all samples."
14571443
14581444 @pytest .mark .parametrize (
14591445 "dims, zerosum_axes" ,
0 commit comments