Skip to content

Commit a65df50

Browse files
author
Vandenplas, Jeremie
committed
mini-batches made at random
1 parent 305c5de commit a65df50

File tree

1 file changed

+8
-3
lines changed

1 file changed

+8
-3
lines changed

src/lib/mod_network.f90

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -122,6 +122,7 @@ pure subroutine backprop(self, y, dw, db)
122122
end subroutine backprop
123123

124124
subroutine fit_batch(self, x, y, eta,epochs,batch_size)
125+
!Performs the training for n epochs with mini-bachtes of size equal to batch_size
125126
class(network_type), intent(in out) :: self
126127
integer(ik),intent(in),optional::epochs,batch_size
127128
real(rk), intent(in) :: x(:,:), y(:,:), eta
@@ -130,6 +131,8 @@ subroutine fit_batch(self, x, y, eta,epochs,batch_size)
130131
integer(ik)::num_epochs,num_batch_size
131132
integer(ik)::batch_start,batch_end
132133

134+
real(rk)::pos
135+
133136
nsamples=size(y,dim=2)
134137

135138
num_epochs=1
@@ -141,11 +144,13 @@ subroutine fit_batch(self, x, y, eta,epochs,batch_size)
141144
nbatch=nsamples/num_batch_size
142145

143146
epoch: do n=1,num_epochs
144-
batch_end=0
145147
mini_batches: do i=1,nbatch
146-
batch_start=batch_end+1
148+
149+
!pull a random mini-batch from the dataset
150+
call random_number(pos)
151+
batch_start=int(pos*(nsamples-num_batch_size+1))
152+
if(batch_start.eq.0)batch_start=1
147153
batch_end=batch_start+batch_size-1
148-
if(i.eq.nbatch)batch_end=nsamples
149154

150155
call self%train(x(:,batch_start:batch_end),y(:,batch_start:batch_end),eta)
151156

0 commit comments

Comments
 (0)