11module mod_network
22
3- use mod_activation, only: gaussian, gaussian_prime,&
4- relu, relu_prime,&
5- sigmoid, sigmoid_prime,&
6- step, step_prime,&
7- tanhf, tanh_prime
83 use mod_kinds, only: ik, rk
94 use mod_layer, only: array1d, array2d, db_init, dw_init,&
105 db_co_sum, dw_co_sum, layer_type
@@ -19,8 +14,6 @@ module mod_network
1914
2015 type (layer_type), allocatable :: layers(:)
2116 integer , allocatable :: dims(:)
22- procedure (activation_function), pointer , nopass :: activation = > null ()
23- procedure (activation_function), pointer , nopass :: activation_prime = > null ()
2417
2518 contains
2619
@@ -49,14 +42,6 @@ module mod_network
4942 module procedure :: net_constructor
5043 endinterface network_type
5144
52- interface
53- pure function activation_function (x )
54- import :: rk
55- real (rk), intent (in ) :: x(:)
56- real (rk) :: activation_function(size (x))
57- end function activation_function
58- end interface
59-
6045contains
6146
6247 type (network_type) function net_constructor(dims, activation) result(net)
@@ -105,13 +90,13 @@ pure subroutine backprop(self, y, dw, db)
10590 call dw_init(dw, dims)
10691
10792 n = size (dims)
108- db(n) % array = (layers(n) % a - y) * self % activation_prime(layers(n) % z)
93+ db(n) % array = (layers(n) % a - y) * self % layers(n) % activation_prime(layers(n) % z)
10994 dw(n-1 ) % array = matmul (reshape (layers(n-1 ) % a, [dims(n-1 ), 1 ]),&
11095 reshape (db(n) % array, [1 , dims(n)]))
11196
11297 do n = size (dims) - 1 , 2 , - 1
11398 db(n) % array = matmul (layers(n) % w, db(n+1 ) % array)&
114- * self % activation_prime(layers(n) % z)
99+ * self % layers(n) % activation_prime(layers(n) % z)
115100 dw(n-1 ) % array = matmul (reshape (layers(n-1 ) % a, [dims(n-1 ), 1 ]),&
116101 reshape (db(n) % array, [1 , dims(n)]))
117102 end do
@@ -130,7 +115,7 @@ pure subroutine fwdprop(self, x)
130115 layers(1 ) % a = x
131116 do n = 2 , size (layers)
132117 layers(n) % z = matmul (transpose (layers(n-1 ) % w), layers(n-1 ) % a) + layers(n) % b
133- layers(n) % a = self % activation(layers(n) % z)
118+ layers(n) % a = self % layers(n) % activation(layers(n) % z)
134119 end do
135120 end associate
136121 end subroutine fwdprop
@@ -184,9 +169,9 @@ pure function output_single(self, x) result(a)
184169 real (rk), allocatable :: a(:)
185170 integer (ik) :: n
186171 associate(layers = > self % layers)
187- a = self % activation(matmul (transpose (layers(1 ) % w), x) + layers(2 ) % b)
172+ a = self % layers( 2 ) % activation(matmul (transpose (layers(1 ) % w), x) + layers(2 ) % b)
188173 do n = 3 , size (layers)
189- a = self % activation(matmul (transpose (layers(n-1 ) % w), a) + layers(n) % b)
174+ a = self % layers(n) % activation(matmul (transpose (layers(n-1 ) % w), a) + layers(n) % b)
190175 end do
191176 end associate
192177 end function output_single
@@ -223,31 +208,15 @@ subroutine save(self, filename)
223208 end subroutine save
224209
225210 pure subroutine set_activation (self , activation )
226- ! Sets the activation functions. Input string must match one of
227- ! provided activation functions, otherwise it defaults to sigmoid.
228- ! If activation not present, defaults to sigmoid.
211+ ! A thin wrapper around layer % set_activation().
212+ ! This method can be used to set an activation function
213+ ! for all layers at once.
229214 class(network_type), intent (in out ) :: self
230215 character (len=* ), intent (in ) :: activation
231- select case (trim (activation))
232- case (' gaussian' )
233- self % activation = > gaussian
234- self % activation_prime = > gaussian_prime
235- case (' relu' )
236- self % activation = > relu
237- self % activation_prime = > relu_prime
238- case (' sigmoid' )
239- self % activation = > sigmoid
240- self % activation_prime = > sigmoid_prime
241- case (' step' )
242- self % activation = > step
243- self % activation_prime = > step_prime
244- case (' tanh' )
245- self % activation = > tanhf
246- self % activation_prime = > tanh_prime
247- case default
248- self % activation = > sigmoid
249- self % activation_prime = > sigmoid_prime
250- end select
216+ integer :: n
217+ do concurrent(n = 1 :size (self % layers))
218+ call self % layers(n) % set_activation(activation)
219+ end do
251220 end subroutine set_activation
252221
253222 subroutine sync (self , image )
0 commit comments