@@ -164,6 +164,7 @@ end function loss
164164
165165 pure function output_single (self , x ) result(a)
166166 ! Use forward propagation to compute the output of the network.
167+ ! This specific procedure is for a single sample of 1-d input data.
167168 class(network_type), intent (in ) :: self
168169 real (rk), intent (in ) :: x(:)
169170 real (rk), allocatable :: a(:)
@@ -177,17 +178,16 @@ pure function output_single(self, x) result(a)
177178 end function output_single
178179
179180 pure function output_batch (self , x ) result(a)
181+ ! Use forward propagation to compute the output of the network.
182+ ! This specific procedure is for a batch of 1-d input data.
180183 class(network_type), intent (in ) :: self
181184 real (rk), intent (in ) :: x(:,:)
182185 real (rk), allocatable :: a(:,:)
183-
184186 integer (ik) :: i
185-
186- allocate (a(self% dims(size (self% dims)),size (x,dim= 2 )))
187+ allocate (a(self % dims(size (self % dims)), size (x, dim= 2 )))
187188 do i = 1 , size (x, dim= 2 )
188- a(:,i)= self% output(x(:,i))
189- enddo
190-
189+ a(:,i) = self % output_single(x(:,i))
190+ end do
191191 end function output_batch
192192
193193 subroutine save (self , filename )
@@ -272,31 +272,30 @@ subroutine train_batch(self, x, y, eta)
272272
273273 end subroutine train_batch
274274
275- subroutine train_epochs (self , x , y , eta ,num_epochs ,num_batch_size )
276- ! Performs the training for nun_epochs epochs with mini-bachtes of size equal to num_batch_size
275+ subroutine train_epochs (self , x , y , eta , num_epochs , batch_size )
276+ ! Trains for num_epochs epochs with mini-bachtes of size equal to batch_size.
277277 class(network_type), intent (in out ) :: self
278- integer (ik),intent (in ):: num_epochs,num_batch_size
278+ integer (ik), intent (in ) :: num_epochs, batch_size
279279 real (rk), intent (in ) :: x(:,:), y(:,:), eta
280280
281- integer (ik):: i,n,nsamples,nbatch
282- integer (ik):: batch_start,batch_end
283-
284- real (rk):: pos
281+ integer (ik) :: i, n, nsamples, nbatch
282+ integer (ik) :: batch_start, batch_end
285283
286- nsamples = size (y,dim = 2 )
284+ real (rk) :: pos
287285
288- nbatch= nsamples/ num_batch_size
286+ nsamples = size (y, dim= 2 )
287+ nbatch = nsamples / batch_size
289288
290- epoch: do n= 1 , num_epochs
291- mini_batches: do i= 1 , nbatch
289+ epoch: do n = 1 , num_epochs
290+ mini_batches: do i = 1 , nbatch
292291
293292 ! pull a random mini-batch from the dataset
294293 call random_number (pos)
295- batch_start= int (pos* (nsamples- num_batch_size + 1 ))
296- if (batch_start.eq. 0 ) batch_start= 1
297- batch_end= batch_start+ num_batch_size - 1
294+ batch_start = int (pos * (nsamples - batch_size + 1 ))
295+ if (batch_start == 0 ) batch_start = 1
296+ batch_end = batch_start + batch_size - 1
298297
299- call self% train(x(:,batch_start:batch_end),y(:,batch_start:batch_end),eta)
298+ call self % train(x(:,batch_start:batch_end), y(:,batch_start:batch_end), eta)
300299
301300 enddo mini_batches
302301 enddo epoch
0 commit comments