@@ -36,7 +36,7 @@ class ConcreteDropout(Wrapper):
3636 prob_init: Tuple[float, float].
3737 Probability lower / upper bounds of dropout rate initialization.
3838 temp: float. Temperature.
39- Determines the speed of probability adjustments.
39+ Determines the speed of probability (i.e. dropout rate) adjustments.
4040 seed: Seed for random probability sampling.
4141
4242 # References
@@ -74,6 +74,7 @@ def _concrete_dropout(self, inputs, layer_type):
7474 # Returns
7575 A tensor with the same shape as inputs and dropout applied.
7676 """
77+ assert layer_type in {'dense' , 'conv2d' }
7778 eps = K .cast_to_floatx (K .epsilon ())
7879
7980 noise_shape = K .shape (inputs )
@@ -93,6 +94,7 @@ def _concrete_dropout(self, inputs, layer_type):
9394 )
9495 drop_prob = K .sigmoid (drop_prob / self .temp )
9596
97+ # apply dropout
9698 random_tensor = 1. - drop_prob
9799 retain_prob = 1. - self .p
98100 inputs *= random_tensor
@@ -104,7 +106,7 @@ def build(self, input_shape=None):
104106 input_shape = to_tuple (input_shape )
105107 if len (input_shape ) == 2 : # Dense_layer
106108 input_dim = np .prod (input_shape [- 1 ]) # we drop only last dim
107- elif len (input_shape ) == 4 : # Conv_layer
109+ elif len (input_shape ) == 4 : # Conv2D_layer
108110 input_dim = (input_shape [1 ]
109111 if K .image_data_format () == 'channels_first'
110112 else input_shape [3 ]) # we drop only channels
@@ -129,7 +131,7 @@ def build(self, input_shape=None):
129131
130132 super (ConcreteDropout , self ).build (input_shape )
131133
132- # initialize regularizer / prior KL term
134+ # initialize regularizer / prior KL term and add to layer-loss
133135 weight = self .layer .kernel
134136 kernel_regularizer = (
135137 self .weight_regularizer
@@ -146,9 +148,7 @@ def build(self, input_shape=None):
146148 def call (self , inputs , training = None ):
147149 def relaxed_dropped_inputs ():
148150 return self .layer .call (self ._concrete_dropout (inputs , (
149- 'dense'
150- if len (K .int_shape (inputs )) == 2
151- else 'conv2d'
151+ 'dense' if len (K .int_shape (inputs )) == 2 else 'conv2d'
152152 )))
153153
154154 return K .in_train_phase (relaxed_dropped_inputs ,
0 commit comments