 module mod_network
 
-  use mod_activation, only: gaussian, gaussian_prime,&
-                            relu, relu_prime,&
-                            sigmoid, sigmoid_prime,&
-                            step, step_prime,&
-                            tanhf, tanh_prime
   use mod_kinds, only: ik, rk
   use mod_layer, only: array1d, array2d, db_init, dw_init,&
                        db_co_sum, dw_co_sum, layer_type
@@ -19,8 +14,6 @@ module mod_network
 
     type(layer_type), allocatable :: layers(:)
     integer, allocatable :: dims(:)
-    procedure(activation_function), pointer, nopass :: activation => null()
-    procedure(activation_function), pointer, nopass :: activation_prime => null()
 
   contains
 
@@ -46,14 +39,6 @@ module mod_network
     module procedure :: net_constructor
   endinterface network_type
 
-  interface
-    pure function activation_function(x)
-      import :: rk
-      real(rk), intent(in) :: x(:)
-      real(rk) :: activation_function(size(x))
-    end function activation_function
-  end interface
-
 contains
 
   type(network_type) function net_constructor(dims, activation) result(net)
@@ -102,13 +87,13 @@ pure subroutine backprop(self, y, dw, db)
     call dw_init(dw, dims)
 
     n = size(dims)
-    db(n) % array = (layers(n) % a - y) * self % activation_prime(layers(n) % z)
+    db(n) % array = (layers(n) % a - y) * self % layers(n) % activation_prime(layers(n) % z)
     dw(n-1) % array = matmul(reshape(layers(n-1) % a, [dims(n-1), 1]),&
                              reshape(db(n) % array, [1, dims(n)]))
 
     do n = size(dims) - 1, 2, -1
       db(n) % array = matmul(layers(n) % w, db(n+1) % array)&
-                    * self % activation_prime(layers(n) % z)
+                    * self % layers(n) % activation_prime(layers(n) % z)
       dw(n-1) % array = matmul(reshape(layers(n-1) % a, [dims(n-1), 1]),&
                                reshape(db(n) % array, [1, dims(n)]))
     end do
@@ -127,7 +112,7 @@ pure subroutine fwdprop(self, x)
       layers(1) % a = x
       do n = 2, size(layers)
         layers(n) % z = matmul(transpose(layers(n-1) % w), layers(n-1) % a) + layers(n) % b
-        layers(n) % a = self % activation(layers(n) % z)
+        layers(n) % a = self % layers(n) % activation(layers(n) % z)
       end do
     end associate
   end subroutine fwdprop
@@ -181,9 +166,9 @@ pure function output(self, x) result(a)
     real(rk), allocatable :: a(:)
     integer(ik) :: n
     associate(layers => self % layers)
-      a = self % activation(matmul(transpose(layers(1) % w), x) + layers(2) % b)
+      a = self % layers(2) % activation(matmul(transpose(layers(1) % w), x) + layers(2) % b)
       do n = 3, size(layers)
-        a = self % activation(matmul(transpose(layers(n-1) % w), a) + layers(n) % b)
+        a = self % layers(n) % activation(matmul(transpose(layers(n-1) % w), a) + layers(n) % b)
       end do
     end associate
   end function output
@@ -206,31 +191,15 @@ subroutine save(self, filename)
   end subroutine save
 
   pure subroutine set_activation(self, activation)
-    ! Sets the activation functions. Input string must match one of
-    ! provided activation functions, otherwise it defaults to sigmoid.
-    ! If activation not present, defaults to sigmoid.
+    ! A thin wrapper around layer % set_activation().
+    ! This method can be used to set an activation function
+    ! for all layers at once.
     class(network_type), intent(in out) :: self
     character(len=*), intent(in) :: activation
-    select case(trim(activation))
-      case('gaussian')
-        self % activation => gaussian
-        self % activation_prime => gaussian_prime
-      case('relu')
-        self % activation => relu
-        self % activation_prime => relu_prime
-      case('sigmoid')
-        self % activation => sigmoid
-        self % activation_prime => sigmoid_prime
-      case('step')
-        self % activation => step
-        self % activation_prime => step_prime
-      case('tanh')
-        self % activation => tanhf
-        self % activation_prime => tanh_prime
-      case default
-        self % activation => sigmoid
-        self % activation_prime => sigmoid_prime
-    end select
+    integer :: n
+    do concurrent(n = 1:size(self % layers))
+      call self % layers(n) % set_activation(activation)
+    end do
   end subroutine set_activation
 
   subroutine sync(self, image)
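
After this change the activation function pointers live on each layer rather than on the network, so an activation can be set for the whole network through the wrapper above or, in principle, overridden per layer. The sketch below is illustrative only and not part of this commit; it assumes that layers(:) remains a public component and that layer_type exposes the set_activation method the wrapper calls, and the program name, network dimensions, and activation strings are made up for the example.

! Illustrative driver, not part of this diff.
program activation_per_layer_demo
  use mod_network, only: network_type
  implicit none
  type(network_type) :: net

  ! Build a small 3-5-2 network; net_constructor (shown above) still
  ! accepts an activation name that is applied to all layers.
  net = network_type([3, 5, 2], 'sigmoid')

  ! The network-level set_activation is now just a loop that forwards
  ! the name to every layer's set_activation.
  call net % set_activation('tanh')

  ! Assumed per-layer override, relying on layers(:) being public and
  ! on layer_type providing set_activation as used in the wrapper above.
  call net % layers(2) % set_activation('relu')

end program activation_per_layer_demo
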