@@ -63,7 +63,6 @@ def __init__(
 
         self.seed_generator = seed_generator or keras.random.SeedGenerator()
 
-        self.log_normalization_constant = None
         self.dim = None
         self._loc = None
         self._scale = None
@@ -78,21 +77,16 @@ def build(self, input_shape: Shape) -> None:
         self.loc = ops.cast(ops.broadcast_to(self.loc, (self.dim,)), "float32")
         self.scale = ops.cast(ops.broadcast_to(self.scale, (self.dim,)), "float32")
 
-        self.log_normalization_constant = (
-            -0.5 * self.dim * math.log(self.df)
-            - 0.5 * self.dim * math.log(math.pi)
-            - math.lgamma(0.5 * self.df)
-            + math.lgamma(0.5 * (self.df + self.dim))
-            - ops.sum(keras.ops.log(self.scale))
-        )
-
         if self.trainable_parameters:
             self._loc = self.add_weight(
-                shape=ops.shape(self.loc), initializer=keras.initializers.get(self.loc), dtype="float32", trainable=True
+                shape=ops.shape(self.loc),
+                initializer=keras.initializers.get(keras.ops.copy(self.loc)),
+                dtype="float32",
+                trainable=True,
             )
             self._scale = self.add_weight(
                 shape=ops.shape(self.scale),
-                initializer=keras.initializers.get(self.scale),
+                initializer=keras.initializers.get(keras.ops.copy(self.scale)),
                 dtype="float32",
                 trainable=True,
             )
@@ -105,7 +99,14 @@ def log_prob(self, samples: Tensor, *, normalize: bool = True) -> Tensor:
         result = -0.5 * (self.df + self.dim) * ops.log1p(mahalanobis_term / self.df)
 
         if normalize:
-            result += self.log_normalization_constant
+            log_normalization_constant = (
+                -0.5 * self.dim * math.log(self.df)
+                - 0.5 * self.dim * math.log(math.pi)
+                - math.lgamma(0.5 * self.df)
+                + math.lgamma(0.5 * (self.df + self.dim))
+                - ops.sum(keras.ops.log(self._scale))
+            )
+            result += log_normalization_constant
 
         return result
 
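For reference, the constant now computed inside `log_prob` (from the trainable `self._scale` weight rather than a value cached in `build`) is the standard normalizer of a multivariate Student-t with diagonal scale: lgamma((df + dim) / 2) - lgamma(df / 2) - (dim / 2) * log(df * pi) - sum(log(scale)). The sketch below is not part of the diff; it assumes `scale` acts as a per-dimension standard deviation (shape matrix = diag(scale ** 2)), which is inferred from the `- ops.sum(log(scale))` term, and checks the formula against SciPy with arbitrary values.

# Sanity-check sketch (assumption: diagonal-scale multivariate Student-t,
# shape matrix = diag(scale ** 2); df, loc, scale, x are arbitrary test values).
import math

import numpy as np
from scipy import stats

df, dim = 5.0, 3
loc = np.array([0.5, -1.0, 2.0])
scale = np.array([1.0, 0.5, 2.0])
x = np.array([0.0, 0.0, 1.0])

# Same normalizer as the code added to log_prob, evaluated with NumPy.
log_normalization_constant = (
    -0.5 * dim * math.log(df)
    - 0.5 * dim * math.log(math.pi)
    - math.lgamma(0.5 * df)
    + math.lgamma(0.5 * (df + dim))
    - np.sum(np.log(scale))
)
mahalanobis_term = np.sum(((x - loc) / scale) ** 2)
log_prob = -0.5 * (df + dim) * np.log1p(mahalanobis_term / df) + log_normalization_constant

# Reference value from SciPy's multivariate t distribution.
reference = stats.multivariate_t(loc=loc, shape=np.diag(scale**2), df=df).logpdf(x)
np.testing.assert_allclose(log_prob, reference)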