Skip to content

Commit 1ea7bc3

Browse files
committed
tidying up
1 parent 4a944e7 commit 1ea7bc3

File tree

2 files changed

+21
-23
lines changed

2 files changed

+21
-23
lines changed

src/lib/mod_network.f90

Lines changed: 20 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -164,6 +164,7 @@ end function loss
164164

165165
pure function output_single(self, x) result(a)
166166
! Use forward propagation to compute the output of the network.
167+
! This specific procedure is for a single sample of 1-d input data.
167168
class(network_type), intent(in) :: self
168169
real(rk), intent(in) :: x(:)
169170
real(rk), allocatable :: a(:)
@@ -177,17 +178,16 @@ pure function output_single(self, x) result(a)
177178
end function output_single
178179

179180
pure function output_batch(self, x) result(a)
181+
! Use forward propagation to compute the output of the network.
182+
! This specific procedure is for a batch of 1-d input data.
180183
class(network_type), intent(in) :: self
181184
real(rk), intent(in) :: x(:,:)
182185
real(rk), allocatable :: a(:,:)
183-
184186
integer(ik) :: i
185-
186-
allocate(a(self%dims(size(self%dims)),size(x,dim=2)))
187+
allocate(a(self % dims(size(self % dims)), size(x, dim=2)))
187188
do i = 1, size(x, dim=2)
188-
a(:,i)=self%output(x(:,i))
189-
enddo
190-
189+
a(:,i) = self % output_single(x(:,i))
190+
end do
191191
end function output_batch
192192

193193
subroutine save(self, filename)
@@ -272,31 +272,30 @@ subroutine train_batch(self, x, y, eta)
272272

273273
end subroutine train_batch
274274

275-
subroutine train_epochs(self, x, y, eta,num_epochs,num_batch_size)
276-
!Performs the training for nun_epochs epochs with mini-bachtes of size equal to num_batch_size
275+
subroutine train_epochs(self, x, y, eta, num_epochs, batch_size)
276+
! Trains for num_epochs epochs with mini-bachtes of size equal to batch_size.
277277
class(network_type), intent(in out) :: self
278-
integer(ik),intent(in)::num_epochs,num_batch_size
278+
integer(ik), intent(in) :: num_epochs, batch_size
279279
real(rk), intent(in) :: x(:,:), y(:,:), eta
280280

281-
integer(ik)::i,n,nsamples,nbatch
282-
integer(ik)::batch_start,batch_end
283-
284-
real(rk)::pos
281+
integer(ik) :: i, n, nsamples, nbatch
282+
integer(ik) :: batch_start, batch_end
285283

286-
nsamples=size(y,dim=2)
284+
real(rk) :: pos
287285

288-
nbatch=nsamples/num_batch_size
286+
nsamples = size(y, dim=2)
287+
nbatch = nsamples / batch_size
289288

290-
epoch: do n=1,num_epochs
291-
mini_batches: do i=1,nbatch
289+
epoch: do n = 1, num_epochs
290+
mini_batches: do i = 1, nbatch
292291

293292
!pull a random mini-batch from the dataset
294293
call random_number(pos)
295-
batch_start=int(pos*(nsamples-num_batch_size+1))
296-
if(batch_start.eq.0)batch_start=1
297-
batch_end=batch_start+num_batch_size-1
294+
batch_start = int(pos * (nsamples - batch_size + 1))
295+
if (batch_start == 0) batch_start = 1
296+
batch_end = batch_start + batch_size - 1
298297

299-
call self%train(x(:,batch_start:batch_end),y(:,batch_start:batch_end),eta)
298+
call self % train(x(:,batch_start:batch_end), y(:,batch_start:batch_end), eta)
300299

301300
enddo mini_batches
302301
enddo epoch

src/tests/example_mnist.f90

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,12 +30,11 @@ program example_mnist
3030
net % accuracy(te_images, label_digits(te_labels)) * 100, ' %'
3131
end if
3232

33-
call net%train(tr_images,label_digits(tr_labels),3._rk,num_epochs,batch_size)
33+
call net % train(tr_images, label_digits(tr_labels), 3._rk, num_epochs, batch_size)
3434

3535
if (this_image() == 1) then
3636
write(*, '(a,f5.2,a)') 'Epochs done, Accuracy: ',&
3737
net % accuracy(te_images, label_digits(te_labels)) * 100, ' %'
3838
endif
3939

40-
4140
end program example_mnist

0 commit comments

Comments
 (0)