@@ -29,13 +29,16 @@ class ConcreteDropout(Wrapper):
 
     # Arguments
         layer: The layer to be wrapped.
-        n_data: int. Length of the dataset.
-        length_scale: float. Prior lengthscale.
-        model_precision: float. Model precision parameter is `1` for classification.
+        n_data: int. `n_data > 0`.
+            Length of the dataset.
+        length_scale: float. `length_scale > 0`.
+            Prior lengthscale.
+        model_precision: float. `model_precision > 0`.
+            Model precision parameter; `1` for classification.
             Also known as inverse observation noise.
-        prob_init: Tuple[float, float].
+        prob_init: Tuple[float, float]. `0 < prob_init[0] <= prob_init[1]`.
             Probability lower / upper bounds of dropout rate initialization.
-        temp: float. Temperature.
+        temp: float. `temp > 0`. Temperature.
             Determines the speed of probability (i.e. dropout rate) adjustments.
         seed: Seed for random probability sampling.
 
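For orientation, a minimal construction sketch that satisfies the documented constraints. The wrapped `Dense` layer and all hyperparameter values here are illustrative, not taken from the patch:

```python
from keras.layers import Dense

# Hypothetical usage; assumes ConcreteDropout is importable from this module.
cd = ConcreteDropout(Dense(64),
                     n_data=10000,          # n_data > 0, dataset length
                     length_scale=1e-2,     # length_scale > 0
                     model_precision=1.,    # > 0; 1 for classification
                     prob_init=(0.1, 0.5),  # 0 < lower <= upper
                     temp=0.1,              # temp > 0
                     seed=42)
```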
@@ -53,13 +56,23 @@ def __init__(self,
                  seed=None,
                  **kwargs):
         assert 'kernel_regularizer' not in kwargs
+        assert n_data > 0 and isinstance(n_data, int)
+        assert length_scale > 0.
+        assert prob_init[0] <= prob_init[1] and prob_init[0] > 0.
+        assert temp > 0.
+        assert model_precision > 0.
         super(ConcreteDropout, self).__init__(layer, **kwargs)
-        self.weight_regularizer = length_scale ** 2 / (model_precision * n_data)
-        self.dropout_regularizer = 2 / (model_precision * n_data)
-        self.prob_init = tuple(np.log(prob_init))
-        self.temp = temp
-        self.seed = seed
 
+        self._n_data = n_data
+        self._length_scale = length_scale
+        self._model_precision = model_precision
+        self._prob_init = prob_init
+        self._temp = temp
+        self._seed = seed
+
+        eps = K.epsilon()
+        self.weight_regularizer = length_scale ** 2 / (model_precision * n_data + eps)
+        self.dropout_regularizer = 2 / (model_precision * n_data + eps)
         self.supports_masking = True
         self.p_logit = None
         self.p = None
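The two coefficients appear to follow the concrete-dropout objective of Gal et al., where the weight term scales as l²/(τN) and the dropout term as 2/(τN), with τ the model precision and N the dataset length. A plain-Python check of that arithmetic, with illustrative values:

```python
length_scale, model_precision, n_data = 1e-2, 1., 10000

# weight term: l**2 / (tau * N); dropout term: 2 / (tau * N)
weight_regularizer = length_scale ** 2 / (model_precision * n_data)  # 1e-08
dropout_regularizer = 2 / (model_precision * n_data)                 # 2e-04
```

Since the new asserts guarantee `model_precision * n_data > 0`, the added `eps` is a belt-and-braces guard against division by zero.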
@@ -84,15 +97,15 @@ def _concrete_dropout(self, inputs, layer_type):
         else:
             noise_shape = (noise_shape[0], 1, 1, noise_shape[3])
         unif_noise = K.random_uniform(shape=noise_shape,
-                                      seed=self.seed,
+                                      seed=self._seed,
                                       dtype=inputs.dtype)
         drop_prob = (
             K.log(self.p + eps)
             - K.log(1. - self.p + eps)
             + K.log(unif_noise + eps)
             - K.log(1. - unif_noise + eps)
         )
-        drop_prob = K.sigmoid(drop_prob / self.temp)
+        drop_prob = K.sigmoid(drop_prob / self._temp)
 
         # apply dropout
         random_tensor = 1. - drop_prob
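The logit expression above is the Concrete (Gumbel-sigmoid) relaxation of a Bernoulli dropout mask: logistic noise is added to the dropout-probability logit and squashed through a temperature-scaled sigmoid, so the mask stays differentiable in `p`. A standalone NumPy sketch of the same computation; all names are local to this example:

```python
import numpy as np

def soft_dropout_mask(p, shape, temp=0.1, eps=1e-7, rng=np.random):
    """Relaxed Bernoulli 'keep' mask, mirroring _concrete_dropout above."""
    u = rng.uniform(size=shape)                      # unif_noise
    logit = (np.log(p + eps) - np.log(1. - p + eps)
             + np.log(u + eps) - np.log(1. - u + eps))
    drop_prob = 1. / (1. + np.exp(-logit / temp))    # sigmoid(logit / temp)
    return 1. - drop_prob                            # cf. random_tensor

mask = soft_dropout_mask(p=0.2, shape=(3, 4))
```

As `temp` approaches zero the sigmoid sharpens and the mask approaches exact Bernoulli samples; larger temperatures give softer, lower-variance masks.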
@@ -123,8 +136,8 @@ def build(self, input_shape=None):
         self.p_logit = self.layer.add_weight(name='p_logit',
                                              shape=(1,),
                                              initializer=RandomUniform(
-                                                 *self.prob_init,
-                                                 seed=self.seed
+                                                 *np.log(self._prob_init),
+                                                 seed=self._seed
                                              ),
                                              trainable=True)
         self.p = K.squeeze(K.sigmoid(self.p_logit), axis=0)
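One subtlety worth noting: the initializer draws `p_logit` uniformly from `[log(prob_init[0]), log(prob_init[1])]`, but `p` is recovered with a sigmoid, so the initial dropout rate is `sigmoid(log(p0))` rather than `p0` itself. Illustrative arithmetic for `prob_init=(0.1, 0.5)`:

```python
import numpy as np

def sigmoid(x):
    return 1. / (1. + np.exp(-x))

lo, hi = np.log((0.1, 0.5))      # p_logit bounds: about -2.303 and -0.693
print(sigmoid(lo), sigmoid(hi))  # initial p lands in roughly [0.091, 0.333]
```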
@@ -156,11 +169,12 @@ def relaxed_dropped_inputs():
                                 training=training)
 
     def get_config(self):
-        config = {'weight_regularizer': self.weight_regularizer,
-                  'dropout_regularizer': self.dropout_regularizer,
-                  'prob_init': tuple(np.round(self.prob_init, 8)),
-                  'temp': self.temp,
-                  'seed': self.seed}
+        config = {'n_data': self._n_data,
+                  'length_scale': self._length_scale,
+                  'model_precision': self._model_precision,
+                  'prob_init': self._prob_init,
+                  'temp': self._temp,
+                  'seed': self._seed}
         base_config = super(ConcreteDropout, self).get_config()
         return dict(list(base_config.items()) + list(config.items()))
 
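Storing the raw constructor arguments instead of the derived regularizer coefficients makes the config round-trippable. A sketch of what that enables, assuming the standard Keras `Wrapper.from_config` machinery deserializes the wrapped layer and forwards the remaining keys to `__init__`; the layer and values are illustrative:

```python
from keras.layers import Dense

cd = ConcreteDropout(Dense(8), n_data=1000, length_scale=1e-2,
                     model_precision=1., prob_init=(0.1, 0.5), temp=0.1)

config = cd.get_config()
# Constructor arguments survive verbatim, so the wrapper can be rebuilt
# and recomputes identical regularizer coefficients.
clone = ConcreteDropout.from_config(config)
assert clone.weight_regularizer == cd.weight_regularizer
```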