From 70782619dc756c750fb117e145a097a994677f45 Mon Sep 17 00:00:00 2001 From: Tarun S Paparaju Date: Fri, 8 Feb 2019 18:59:01 +0530 Subject: [PATCH 1/7] Create dropconnect.py --- keras_contrib/wrappers/dropconnect.py | 1 + 1 file changed, 1 insertion(+) create mode 100644 keras_contrib/wrappers/dropconnect.py diff --git a/keras_contrib/wrappers/dropconnect.py b/keras_contrib/wrappers/dropconnect.py new file mode 100644 index 000000000..8b1378917 --- /dev/null +++ b/keras_contrib/wrappers/dropconnect.py @@ -0,0 +1 @@ + From 94c3b1283d359df911c6bc3c7583dcd43bc887d1 Mon Sep 17 00:00:00 2001 From: Tarun S Paparaju Date: Fri, 8 Feb 2019 19:08:41 +0530 Subject: [PATCH 2/7] Add DropConnect Wrapper --- keras_contrib/wrappers/dropconnect.py | 53 +++++++++++++++++++++++++++ 1 file changed, 53 insertions(+) diff --git a/keras_contrib/wrappers/dropconnect.py b/keras_contrib/wrappers/dropconnect.py index 8b1378917..6500b53e1 100644 --- a/keras_contrib/wrappers/dropconnect.py +++ b/keras_contrib/wrappers/dropconnect.py @@ -1 +1,54 @@ +from keras import backend as K +from keras import activations +from keras import initializers +from keras import regularizers +from keras import constraints +from keras.layers import InputSpec +from keras.layers import Dense +from keras.layers import Layer +from keras.layers import Wrapper + +class DropConnect(Wrapper): + """ + An implementation of DropConnect wrapper in Keras. + This layer drops connections between a one layer and + the next layer randomly with a given probability (rather + than dropping activations as in classic Dropout). + + This wrapper can be used to drop the connections from + any Keras layer (Dense, LSTM etc) + + #Example usage + dense = DropConnect(Dense(10, activation='sigmoid'), prob=0.05) + lstm = DropConnect(LSTM(20, activation='relu'), prob=0.2) + + #Arguments + layer : Any Keras layer (instance of Layer class) + prob : dropout rate (probability) + + #References + https://github.com/andry9454/KerasDropconnect/blob/master/ddrop/layers.py + """ + def __init__(self, layer, prob=0.1, **kwargs): + self.prob = prob + self.layer = layer + super(DropConnect, self).__init__(layer, **kwargs) + if 0. < self.prob < 1.: + self.uses_learning_phase = True + + def build(self, input_shape): + if not self.layer.built: + self.layer.build(input_shape) + self.layer.built = True + + def compute_output_shape(self, input_shape): + return self.layer.compute_output_shape(input_shape) + + def call(self, x): + if 0. < self.prob < 1.: + self.layer.kernel = K.in_train_phase(K.dropout(self.layer.kernel, self.prob), + self.layer.kernel) + self.layer.bias = K.in_train_phase(K.dropout(self.layer.bias, self.prob), + self.layer.bias) + return self.layer.call(x) From 8c68ca1085d305c29e5c8e2fb8ef3ea3b9bf4d25 Mon Sep 17 00:00:00 2001 From: Tarun S Paparaju Date: Fri, 8 Feb 2019 19:10:18 +0530 Subject: [PATCH 3/7] Add DropConnect import in __init__.py --- keras_contrib/wrappers/__init__.py | 1 + 1 file changed, 1 insertion(+) diff --git a/keras_contrib/wrappers/__init__.py b/keras_contrib/wrappers/__init__.py index e69de29bb..34d400e9d 100644 --- a/keras_contrib/wrappers/__init__.py +++ b/keras_contrib/wrappers/__init__.py @@ -0,0 +1 @@ +from .dropconnect import DropConnect From 865ec846c8868b657c313aa740cd5f63154563d1 Mon Sep 17 00:00:00 2001 From: Tarun S Paparaju Date: Fri, 8 Feb 2019 19:16:37 +0530 Subject: [PATCH 4/7] Remove unused imports and add necessary ones --- keras_contrib/wrappers/dropconnect.py | 7 ------- 1 file changed, 7 deletions(-) diff --git a/keras_contrib/wrappers/dropconnect.py b/keras_contrib/wrappers/dropconnect.py index 6500b53e1..2453cd977 100644 --- a/keras_contrib/wrappers/dropconnect.py +++ b/keras_contrib/wrappers/dropconnect.py @@ -1,11 +1,4 @@ from keras import backend as K -from keras import activations -from keras import initializers -from keras import regularizers -from keras import constraints -from keras.layers import InputSpec -from keras.layers import Dense -from keras.layers import Layer from keras.layers import Wrapper From c11ec70dc00d323a25a3bb56e9c0908fa21efee0 Mon Sep 17 00:00:00 2001 From: Tarun S Paparaju Date: Fri, 8 Feb 2019 19:28:00 +0530 Subject: [PATCH 5/7] Change import --- keras_contrib/wrappers/dropconnect.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/keras_contrib/wrappers/dropconnect.py b/keras_contrib/wrappers/dropconnect.py index 2453cd977..fd65dbe18 100644 --- a/keras_contrib/wrappers/dropconnect.py +++ b/keras_contrib/wrappers/dropconnect.py @@ -1,5 +1,5 @@ from keras import backend as K -from keras.layers import Wrapper +from keras.layers.wrappers import Wrapper class DropConnect(Wrapper): From 59ed647567a35760e8b780c53c3461006d3b47e2 Mon Sep 17 00:00:00 2001 From: Tarun S Paparaju Date: Fri, 8 Feb 2019 19:34:00 +0530 Subject: [PATCH 6/7] Fix pep8 violations --- keras_contrib/wrappers/dropconnect.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/keras_contrib/wrappers/dropconnect.py b/keras_contrib/wrappers/dropconnect.py index fd65dbe18..2da050513 100644 --- a/keras_contrib/wrappers/dropconnect.py +++ b/keras_contrib/wrappers/dropconnect.py @@ -40,8 +40,10 @@ def compute_output_shape(self, input_shape): def call(self, x): if 0. < self.prob < 1.: - self.layer.kernel = K.in_train_phase(K.dropout(self.layer.kernel, self.prob), + self.layer.kernel = K.in_train_phase(K.dropout(self.layer.kernel, + self.prob), self.layer.kernel) - self.layer.bias = K.in_train_phase(K.dropout(self.layer.bias, self.prob), + self.layer.bias = K.in_train_phase(K.dropout(self.layer.bias, + self.prob), self.layer.bias) return self.layer.call(x) From 62094bde24bda260ec1f712aefcd8a4d5c4f1705 Mon Sep 17 00:00:00 2001 From: Tarun S Paparaju Date: Sat, 9 Feb 2019 22:15:57 +0530 Subject: [PATCH 7/7] Add get_config function --- keras_contrib/wrappers/dropconnect.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/keras_contrib/wrappers/dropconnect.py b/keras_contrib/wrappers/dropconnect.py index 2da050513..170d50cea 100644 --- a/keras_contrib/wrappers/dropconnect.py +++ b/keras_contrib/wrappers/dropconnect.py @@ -10,7 +10,7 @@ class DropConnect(Wrapper): than dropping activations as in classic Dropout). This wrapper can be used to drop the connections from - any Keras layer (Dense, LSTM etc) + any Keras layer with weights and biases (Dense, LSTM etc) #Example usage dense = DropConnect(Dense(10, activation='sigmoid'), prob=0.05) @@ -47,3 +47,6 @@ def call(self, x): self.prob), self.layer.bias) return self.layer.call(x) + + def get_config(self): + return self.layer.get_config(self)