@@ -2468,12 +2468,7 @@ def __new__(cls, *args, zerosum_axes=None, support_shape=None, dims=None, **kwar
24682468
24692469 @classmethod
24702470 def dist (cls , sigma = 1 , zerosum_axes = None , support_shape = None , ** kwargs ):
2471- if zerosum_axes is None :
2472- zerosum_axes = 1
2473- if not isinstance (zerosum_axes , int ):
2474- raise TypeError ("zerosum_axes has to be an integer" )
2475- if not zerosum_axes > 0 :
2476- raise ValueError ("zerosum_axes has to be > 0" )
2471+ zerosum_axes = cls .check_zerosum_axes (zerosum_axes )
24772472
24782473 sigma = at .as_tensor_variable (floatX (sigma ))
24792474 if sigma .ndim > 0 :
@@ -2501,21 +2496,6 @@ def dist(cls, sigma=1, zerosum_axes=None, support_shape=None, **kwargs):
25012496 [sigma ], zerosum_axes = zerosum_axes , support_shape = support_shape , ** kwargs
25022497 )
25032498
2504- # TODO: This is if we want ZeroSum constraint on other dists than Normal
2505- # def dist(cls, dist, lower, upper, **kwargs):
2506- # if not isinstance(dist, TensorVariable) or not isinstance(
2507- # dist.owner.op, (RandomVariable, SymbolicRandomVariable)
2508- # ):
2509- # raise ValueError(
2510- # f"Censoring dist must be a distribution created via the `.dist()` API, got {type(dist)}"
2511- # )
2512- # if dist.owner.op.ndim_supp > 0:
2513- # raise NotImplementedError(
2514- # "Censoring of multivariate distributions has not been implemented yet"
2515- # )
2516- # check_dist_not_registered(dist)
2517- # return super().dist([dist, lower, upper], **kwargs)
2518-
25192499 @classmethod
25202500 def check_zerosum_axes (cls , zerosum_axes : Optional [int ]) -> int :
25212501 if zerosum_axes is None :
0 commit comments