@@ -23,23 +23,26 @@ module mod_network
     procedure, public, pass(self) :: init
     procedure, public, pass(self) :: load
     procedure, public, pass(self) :: loss
-    procedure, public, pass(self) :: output
+    procedure, public, pass(self) :: output_batch
+    procedure, public, pass(self) :: output_single
     procedure, public, pass(self) :: save
     procedure, public, pass(self) :: set_activation_equal
     procedure, public, pass(self) :: set_activation_layers
     procedure, public, pass(self) :: sync
     procedure, public, pass(self) :: train_batch
+    procedure, public, pass(self) :: train_epochs
     procedure, public, pass(self) :: train_single
     procedure, public, pass(self) :: update

+    generic, public :: output => output_batch, output_single
     generic, public :: set_activation => set_activation_equal, set_activation_layers
-    generic, public :: train => train_batch, train_single
+    generic, public :: train => train_batch, train_epochs, train_single

   end type network_type

   interface network_type
     module procedure :: net_constructor
-  endinterface network_type
+  end interface network_type

 contains

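For review purposes, a quick sketch of how the reworked generics resolve. Everything below is illustrative, not part of the patch: net, x1, xb, t1, and tb are hypothetical, rk is this library's real kind, and the integer literals assume ik matches the default integer kind.

    real(rk) :: x1(3), t1(2)            ! one sample and its target
    real(rk) :: xb(3,100), tb(2,100)    ! 100 samples stored as columns
    real(rk), allocatable :: y1(:), yb(:,:)
    y1 = net % output(x1)               ! rank-1 input resolves to output_single
    yb = net % output(xb)               ! rank-2 input resolves to output_batch
    call net % train(x1, t1, 0.1_rk)    ! rank-1 data resolves to train_single
    call net % train(xb, tb, 0.1_rk)    ! rank-2 data resolves to train_batch
    call net % train(xb, tb, 0.1_rk, 5, 10)  ! extra args resolve to train_epochs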
@@ -58,6 +61,7 @@ type(network_type) function net_constructor(dims, activation) result(net)
     call net % sync(1)
   end function net_constructor

+
   pure real(rk) function accuracy(self, x, y)
     ! Given input x and output y, evaluates the position of the
     ! maximum value of the output and returns the number of matches
@@ -74,6 +78,7 @@ pure real(rk) function accuracy(self, x, y)
     accuracy = real(good) / size(x, dim=2)
   end function accuracy

+
   pure subroutine backprop(self, y, dw, db)
     ! Applies a backward propagation through the network
     ! and returns the weight and bias gradients.
@@ -104,6 +109,7 @@ pure subroutine backprop(self, y, dw, db)

   end subroutine backprop

+
   pure subroutine fwdprop(self, x)
     ! Performs the forward propagation and stores arguments to activation
     ! functions and activations themselves for use in backprop.
@@ -119,6 +125,7 @@ pure subroutine fwdprop(self, x)
     end associate
   end subroutine fwdprop

+
   subroutine init(self, dims)
     ! Allocates and initializes the layers with given dimensions dims.
     class(network_type), intent(in out) :: self
@@ -134,6 +141,7 @@ subroutine init(self, dims)
     self % layers(size(dims)) % w = 0
   end subroutine init

+
   subroutine load(self, filename)
     ! Loads the network from file.
     class(network_type), intent(in out) :: self
@@ -142,32 +150,35 @@ subroutine load(self, filename)
     integer(ik), allocatable :: dims(:)
     character(len=100) :: buffer ! activation string
     open(newunit=fileunit, file=filename, status='old', action='read')
-    read(fileunit, fmt=*) num_layers
+    read(fileunit, *) num_layers
     allocate(dims(num_layers))
-    read(fileunit, fmt=*) dims
+    read(fileunit, *) dims
     call self % init(dims)
     do n = 1, num_layers
-      read(fileunit, fmt=*) layer_idx, buffer
+      read(fileunit, *) layer_idx, buffer
       call self % layers(layer_idx) % set_activation(trim(buffer))
     end do
     do n = 2, size(self % dims)
-      read(fileunit, fmt=*) self % layers(n) % b
+      read(fileunit, *) self % layers(n) % b
     end do
     do n = 1, size(self % dims) - 1
-      read(fileunit, fmt=*) self % layers(n) % w
+      read(fileunit, *) self % layers(n) % w
     end do
     close(fileunit)
   end subroutine load

+
   pure real(rk) function loss(self, x, y)
     ! Given input x and expected output y, returns the loss of the network.
     class(network_type), intent(in) :: self
     real(rk), intent(in) :: x(:), y(:)
     loss = 0.5 * sum((y - self % output(x))**2) / size(x)
   end function loss

-  pure function output(self, x) result(a)
+
+  pure function output_single(self, x) result(a)
     ! Use forward propagation to compute the output of the network.
+    ! This specific procedure is for a single sample of 1-d input data.
     class(network_type), intent(in) :: self
     real(rk), intent(in) :: x(:)
     real(rk), allocatable :: a(:)
@@ -178,7 +189,22 @@ pure function output(self, x) result(a)
       a = self % layers(n) % activation(matmul(transpose(layers(n-1) % w), a) + layers(n) % b)
     end do
     end associate
-  end function output
+  end function output_single
+
+
+  pure function output_batch(self, x) result(a)
+    ! Use forward propagation to compute the output of the network.
+    ! This specific procedure is for a batch of 1-d input data.
+    class(network_type), intent(in) :: self
+    real(rk), intent(in) :: x(:,:)
+    real(rk), allocatable :: a(:,:)
+    integer(ik) :: i
+    allocate(a(self % dims(size(self % dims)), size(x, dim=2)))
+    do i = 1, size(x, dim=2)
+      a(:,i) = self % output_single(x(:,i))
+    end do
+  end function output_batch
+

   subroutine save(self, filename)
     ! Saves the network to a file.
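One convention worth noting for reviewers: samples live in columns, so output_batch allocates its result as (output size, number of samples) and simply loops output_single over the columns. A shape-only sketch with a hypothetical 3-5-2 network (the optional activation argument to the constructor is assumed to default):

    type(network_type) :: net
    real(rk) :: xb(3, 10)          ! 10 samples, one per column
    real(rk), allocatable :: yb(:,:)
    net = network_type([3, 5, 2])
    call random_number(xb)
    yb = net % output(xb)          ! allocated as (2, 10) by output_batch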
@@ -200,6 +226,7 @@ subroutine save(self, filename)
     close(fileunit)
   end subroutine save

+
   pure subroutine set_activation_equal(self, activation)
     ! A thin wrapper around layer % set_activation().
     ! This method can be used to set an activation function
@@ -209,6 +236,7 @@ pure subroutine set_activation_equal(self, activation)
     call self % layers(:) % set_activation(activation)
   end subroutine set_activation_equal

+
   pure subroutine set_activation_layers(self, activation)
     ! A thin wrapper around layer % set_activation().
     ! This method can be used to set different activation functions
@@ -233,6 +261,7 @@ subroutine sync(self, image)
     end do layers
   end subroutine sync

+
   subroutine train_batch(self, x, y, eta)
     ! Trains a network using input data x and output data y,
     ! and learning rate eta. The learning rate is normalized
@@ -273,6 +302,38 @@ subroutine train_batch(self, x, y, eta)

   end subroutine train_batch

+
+  subroutine train_epochs(self, x, y, eta, num_epochs, batch_size)
+    ! Trains for num_epochs epochs with mini-batches of size equal to batch_size.
+    class(network_type), intent(in out) :: self
+    integer(ik), intent(in) :: num_epochs, batch_size
+    real(rk), intent(in) :: x(:,:), y(:,:), eta
+
+    integer(ik) :: i, n, nsamples, nbatch
+    integer(ik) :: batch_start, batch_end
+
+    real(rk) :: pos
+
+    nsamples = size(y, dim=2)
+    nbatch = nsamples / batch_size
+
+    epochs: do n = 1, num_epochs
+      batches: do i = 1, nbatch
+
+        ! pull a random mini-batch from the dataset
+        call random_number(pos)
+        batch_start = int(pos * (nsamples - batch_size + 1))
+        if (batch_start == 0) batch_start = 1
+        batch_end = batch_start + batch_size - 1
+
+        call self % train(x(:,batch_start:batch_end), y(:,batch_start:batch_end), eta)
+
+      end do batches
+    end do epochs
+
+  end subroutine train_epochs
+
+
   pure subroutine train_single(self, x, y, eta)
     ! Trains a network using a single set of input data x and output data y,
     ! and learning rate eta.
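A review observation on train_epochs: each mini-batch is drawn at a random offset rather than from a shuffled permutation, so batches can overlap within an epoch and some samples may be skipped. Also, since pos < 1, int(pos * (nsamples - batch_size + 1)) is at most nsamples - batch_size, so batch_end never reaches nsamples and the last sample is never drawn; likely harmless for large datasets, but worth noting. A hypothetical call through the generic binding, with x_train/y_train holding samples in columns:

    call net % train(x_train, y_train, eta=3.0_rk, num_epochs=30, batch_size=100)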
@@ -285,6 +346,7 @@ pure subroutine train_single(self, x, y, eta)
     call self % update(dw, db, eta)
   end subroutine train_single

+
   pure subroutine update(self, dw, db, eta)
     ! Updates network weights and biases with gradients dw and db,
     ! scaled by learning rate eta.
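Putting the pieces together, an end-to-end sketch of the API after this change. This is illustrative only: the mod_kinds module name and MNIST-style shapes are assumptions on my part, the data arrays are left unfilled, and I'm assuming accuracy is type-bound like the procedures above.

    program example
      use mod_kinds, only: ik, rk
      use mod_network, only: network_type
      implicit none
      type(network_type) :: net
      real(rk), allocatable :: x(:,:), y(:,:)
      ! ... load training inputs into x and targets into y, samples in columns ...
      net = network_type([784, 30, 10])
      call net % set_activation('sigmoid')
      call net % train(x, y, 3.0_rk, num_epochs=30, batch_size=10)
      print *, 'accuracy =', net % accuracy(x, y)
    end program example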