@@ -122,6 +122,7 @@ pure subroutine backprop(self, y, dw, db)
122122 end subroutine backprop
123123
124124 subroutine fit_batch (self , x , y , eta ,epochs ,batch_size )
125+ ! Performs the training for n epochs with mini-batches of size equal to batch_size
125126 class(network_type), intent (in out ) :: self
126127 integer (ik),intent (in ),optional :: epochs,batch_size
127128 real (rk), intent (in ) :: x(:,:), y(:,:), eta
@@ -130,6 +131,8 @@ subroutine fit_batch(self, x, y, eta,epochs,batch_size)
130131 integer (ik):: num_epochs,num_batch_size
131132 integer (ik):: batch_start,batch_end
132133
134+ real (rk):: pos
135+
133136 nsamples= size (y,dim= 2 )
134137
135138 num_epochs= 1
@@ -141,11 +144,13 @@ subroutine fit_batch(self, x, y, eta,epochs,batch_size)
141144 nbatch= nsamples/ num_batch_size
142145
143146 epoch: do n= 1 ,num_epochs
144- batch_end= 0
145147 mini_batches: do i= 1 ,nbatch
146- batch_start= batch_end+1
148+
149+ ! pull a random mini-batch from the dataset
150+ call random_number (pos)
151+ batch_start= int (pos* (nsamples- num_batch_size+1 ))
152+ if (batch_start.eq. 0 )batch_start= 1
147153 batch_end= batch_start+ batch_size-1
148- if (i.eq. nbatch)batch_end= nsamples
149154
150155 call self% train(x(:,batch_start:batch_end),y(:,batch_start:batch_end),eta)
151156
0 commit comments