Skip to content

Commit 16f6eea

Browse files
author
Vandenplas, Jeremie
committed
addition of the methods fit and predict + simplification of example mnist
1 parent ef9efbc commit 16f6eea

File tree

2 files changed

+58
-26
lines changed

2 files changed

+58
-26
lines changed

src/lib/mod_network.f90

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,18 +26,22 @@ module mod_network
2626

2727
procedure, public, pass(self) :: accuracy
2828
procedure, public, pass(self) :: backprop
29+
procedure, public, pass(self) :: fit_batch
2930
procedure, public, pass(self) :: fwdprop
3031
procedure, public, pass(self) :: init
3132
procedure, public, pass(self) :: load
3233
procedure, public, pass(self) :: loss
3334
procedure, public, pass(self) :: output
35+
procedure, public, pass(self) :: predict_batch
3436
procedure, public, pass(self) :: save
3537
procedure, public, pass(self) :: set_activation
3638
procedure, public, pass(self) :: sync
3739
procedure, public, pass(self) :: train_batch
3840
procedure, public, pass(self) :: train_single
3941
procedure, public, pass(self) :: update
4042

43+
generic, public :: fit => fit_batch!, train_single
44+
generic, public :: predict => predict_batch!, train_single
4145
generic, public :: train => train_batch, train_single
4246

4347
end type network_type
@@ -117,6 +121,39 @@ pure subroutine backprop(self, y, dw, db)
117121

118122
end subroutine backprop
119123

124+
subroutine fit_batch(self, x, y, eta,epochs,batch_size)
125+
class(network_type), intent(in out) :: self
126+
integer(ik),intent(in),optional::epochs,batch_size
127+
real(rk), intent(in) :: x(:,:), y(:,:), eta
128+
129+
integer(ik)::i,n,nsamples,nbatch
130+
integer(ik)::num_epochs,num_batch_size
131+
integer(ik)::batch_start,batch_end
132+
133+
nsamples=size(y,dim=2)
134+
135+
num_epochs=1
136+
if(present(epochs))num_epochs=epochs
137+
138+
num_batch_size=nsamples
139+
if(present(batch_size))num_batch_size=batch_size
140+
141+
nbatch=nsamples/num_batch_size
142+
143+
epoch: do n=1,num_epochs
144+
batch_end=0
145+
mini_batches: do i=1,nbatch
146+
batch_start=batch_end+1
147+
batch_end=batch_start+batch_size-1
148+
if(i.eq.nbatch)batch_end=nsamples
149+
150+
call self%train(x(:,batch_start:batch_end),y(:,batch_start:batch_end),eta)
151+
152+
enddo mini_batches
153+
enddo epoch
154+
155+
end subroutine
156+
120157
pure subroutine fwdprop(self, x)
121158
! Performs the forward propagation and stores arguments to activation
122159
! functions and activations themselves for use in backprop.
@@ -188,6 +225,20 @@ pure function output(self, x) result(a)
188225
end associate
189226
end function output
190227

228+
pure function predict_batch(self, x) result(a)
229+
class(network_type), intent(in) :: self
230+
real(rk), intent(in) :: x(:,:)
231+
real(rk), allocatable :: a(:,:)
232+
233+
integer(ik) :: i
234+
235+
allocate(a(self%dims(size(self%dims)),size(x,dim=2)))
236+
do i = 1, size(x, dim=2)
237+
a(:,i)=self%output(x(:,i))
238+
enddo
239+
240+
end function predict_batch
241+
191242
subroutine save(self, filename)
192243
! Saves the network to a file.
193244
class(network_type), intent(in out) :: self

src/tests/example_mnist.f90

Lines changed: 7 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -12,14 +12,11 @@ program example_mnist
1212

1313
real(rk), allocatable :: tr_images(:,:), tr_labels(:)
1414
real(rk), allocatable :: te_images(:,:), te_labels(:)
15-
!real(rk), allocatable :: va_images(:,:), va_labels(:)
16-
real(rk), allocatable :: input(:,:), output(:,:)
1715

1816
type(network_type) :: net
1917

2018
integer(ik) :: i, n, num_epochs
21-
integer(ik) :: batch_size, batch_start, batch_end
22-
real(rk) :: pos
19+
integer(ik) :: batch_size
2320

2421
call load_mnist(tr_images, tr_labels, te_images, te_labels)
2522

@@ -33,28 +30,12 @@ program example_mnist
3330
net % accuracy(te_images, label_digits(te_labels)) * 100, ' %'
3431
end if
3532

36-
epochs: do n = 1, num_epochs
37-
mini_batches: do i = 1, size(tr_labels) / batch_size
38-
39-
! pull a random mini-batch from the dataset
40-
call random_number(pos)
41-
batch_start = int(pos * (size(tr_labels) - batch_size + 1))
42-
batch_end = batch_start + batch_size - 1
43-
44-
! prepare mini-batch
45-
input = tr_images(:,batch_start:batch_end)
46-
output = label_digits(tr_labels(batch_start:batch_end))
47-
48-
! train the network on the mini-batch
49-
call net % train(input, output, eta=3._rk)
50-
51-
end do mini_batches
52-
53-
if (this_image() == 1) then
54-
write(*, '(a,i2,a,f5.2,a)') 'Epoch ', n, ' done, Accuracy: ',&
55-
net % accuracy(te_images, label_digits(te_labels)) * 100, ' %'
56-
end if
33+
call net%fit(tr_images,label_digits(tr_labels),eta=3._rk,epochs=num_epochs,batch_size=batch_size)
34+
35+
if (this_image() == 1) then
36+
write(*, '(a,f5.2,a)') 'Epochs done, Accuracy: ',&
37+
net % accuracy(te_images, label_digits(te_labels)) * 100, ' %'
38+
endif
5739

58-
end do epochs
5940

6041
end program example_mnist

0 commit comments

Comments
 (0)