@@ -26,23 +26,22 @@ module mod_network
 
     procedure, public, pass(self) :: accuracy
     procedure, public, pass(self) :: backprop
-    procedure, public, pass(self) :: fit_batch
     procedure, public, pass(self) :: fwdprop
     procedure, public, pass(self) :: init
     procedure, public, pass(self) :: load
     procedure, public, pass(self) :: loss
-    procedure, public, pass(self) :: output
-    procedure, public, pass(self) :: predict_batch
+    procedure, public, pass(self) :: output_batch
+    procedure, public, pass(self) :: output_single
     procedure, public, pass(self) :: save
     procedure, public, pass(self) :: set_activation
     procedure, public, pass(self) :: sync
     procedure, public, pass(self) :: train_batch
+    procedure, public, pass(self) :: train_epochs
     procedure, public, pass(self) :: train_single
     procedure, public, pass(self) :: update
 
-    generic, public :: fit => fit_batch !, train_single
-    generic, public :: predict => predict_batch !, train_single
-    generic, public :: train => train_batch, train_single
+    generic, public :: output => output_batch, output_single
+    generic, public :: train => train_batch, train_epochs, train_single
 
   end type network_type
 
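A short usage sketch of the renamed generic (not part of this commit; net, x, and x_batch are hypothetical names for a trained network_type and rank-1/rank-2 input arrays):

  ! Assumed client code: the output generic dispatches on argument rank.
  y_single = net % output(x)        ! x(:)   resolves to output_single
  y_batch  = net % output(x_batch)  ! x(:,:) resolves to output_batch
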
@@ -121,44 +120,6 @@ pure subroutine backprop(self, y, dw, db)
 
   end subroutine backprop
 
-  subroutine fit_batch(self, x, y, eta, epochs, batch_size)
-    ! Performs the training for n epochs with mini-batches of size equal to batch_size
-    class(network_type), intent(in out) :: self
-    integer(ik), intent(in), optional :: epochs, batch_size
-    real(rk), intent(in) :: x(:,:), y(:,:), eta
-
-    integer(ik) :: i, n, nsamples, nbatch
-    integer(ik) :: num_epochs, num_batch_size
-    integer(ik) :: batch_start, batch_end
-
-    real(rk) :: pos
-
-    nsamples = size(y, dim=2)
-
-    num_epochs = 1
-    if (present(epochs)) num_epochs = epochs
-
-    num_batch_size = nsamples
-    if (present(batch_size)) num_batch_size = batch_size
-
-    nbatch = nsamples / num_batch_size
-
-    epoch: do n = 1, num_epochs
-      mini_batches: do i = 1, nbatch
-
-        ! pull a random mini-batch from the dataset
-        call random_number(pos)
-        batch_start = int(pos * (nsamples - num_batch_size + 1))
-        if (batch_start .eq. 0) batch_start = 1
-        batch_end = batch_start + batch_size - 1
-
-        call self % train(x(:,batch_start:batch_end), y(:,batch_start:batch_end), eta)
-
-      enddo mini_batches
-    enddo epoch
-
-  end subroutine
-
   pure subroutine fwdprop(self, x)
     ! Performs the forward propagation and stores arguments to activation
     ! functions and activations themselves for use in backprop.
@@ -216,7 +177,7 @@ pure real(rk) function loss(self, x, y)
     loss = 0.5 * sum((y - self % output(x))**2) / size(x)
   end function loss
 
-  pure function output(self, x) result(a)
+  pure function output_single(self, x) result(a)
     ! Use forward propagation to compute the output of the network.
     class(network_type), intent(in) :: self
     real(rk), intent(in) :: x(:)
@@ -228,9 +189,9 @@ pure function output(self, x) result(a)
         a = self % activation(matmul(transpose(layers(n-1) % w), a) + layers(n) % b)
       end do
     end associate
-  end function output
+  end function output_single
 
-  pure function predict_batch(self, x) result(a)
+  pure function output_batch(self, x) result(a)
     class(network_type), intent(in) :: self
     real(rk), intent(in) :: x(:,:)
     real(rk), allocatable :: a(:,:)
@@ -242,7 +203,7 @@ pure function predict_batch(self, x) result(a)
       a(:,i) = self % output(x(:,i))
     enddo
 
-  end function predict_batch
+  end function output_batch
 
   subroutine save(self, filename)
     ! Saves the network to a file.
@@ -342,6 +303,37 @@ subroutine train_batch(self, x, y, eta)
 
   end subroutine train_batch
 
+  subroutine train_epochs(self, x, y, eta, num_epochs, num_batch_size)
+    ! Performs the training for num_epochs epochs with mini-batches of size equal to num_batch_size
+    class(network_type), intent(in out) :: self
+    integer(ik), intent(in) :: num_epochs, num_batch_size
+    real(rk), intent(in) :: x(:,:), y(:,:), eta
+
+    integer(ik) :: i, n, nsamples, nbatch
+    integer(ik) :: batch_start, batch_end
+
+    real(rk) :: pos
+
+    nsamples = size(y, dim=2)
+
+    nbatch = nsamples / num_batch_size
+
+    epoch: do n = 1, num_epochs
+      mini_batches: do i = 1, nbatch
+
+        ! pull a random mini-batch from the dataset
+        call random_number(pos)
+        batch_start = int(pos * (nsamples - num_batch_size + 1))
+        if (batch_start .eq. 0) batch_start = 1
+        batch_end = batch_start + num_batch_size - 1
+
+        call self % train(x(:,batch_start:batch_end), y(:,batch_start:batch_end), eta)
+
+      enddo mini_batches
+    enddo epoch
+
+  end subroutine train_epochs
+
   pure subroutine train_single(self, x, y, eta)
     ! Trains a network using a single set of input data x and output data y,
     ! and learning rate eta.
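The new train_epochs loops num_epochs times over nsamples / num_batch_size mini-batches drawn at a random offset and hands each slice to the existing batch trainer through the train generic. A minimal driver sketch, assuming a network net, rank-2 arrays x_train and y_train, and rk imported from mod_kinds (all hypothetical names outside this commit):

  ! Assumed client code: the trailing integer arguments make the generic
  ! resolve to train_epochs rather than train_batch.
  call net % train(x_train, y_train, eta=0.1_rk, num_epochs=20, num_batch_size=100)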