Skip to content

Commit 326a612

Browse files
committed
tidy up contributions from @jvdp1
1 parent 1ea7bc3 commit 326a612

File tree

7 files changed

+82
-308
lines changed

7 files changed

+82
-308
lines changed

CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,7 @@ foreach(execid mnist network_save network_sync set_activation_function)
7878
add_test(test_${execid} bin/test_${execid})
7979
endforeach()
8080

81-
foreach(execid mnist montesinos_uni montesinos_multi save_and_load simple sine)
81+
foreach(execid mnist mnist_epochs save_and_load simple sine)
8282
add_executable(example_${execid} src/tests/example_${execid}.f90)
8383
target_link_libraries(example_${execid} neural ${LIBS})
8484
add_test(example_${execid} bin/example_${execid})

README.md

Lines changed: 6 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -266,13 +266,11 @@ program example_mnist
266266
batch_size = 100
267267
num_epochs = 10
268268
269-
if (this_image() == 1) then
270-
write(*, '(a,f5.2,a)') 'Initial accuracy: ',&
271-
net % accuracy(te_images, label_digits(te_labels)) * 100, ' %'
272-
end if
269+
if (this_image() == 1) print '(a,f5.2,a)', 'Initial accuracy: ', &
270+
net % accuracy(te_images, label_digits(te_labels)) * 100, ' %'
273271
274272
epochs: do n = 1, num_epochs
275-
mini_batches: do i = 1, size(tr_labels) / batch_size
273+
batches: do i = 1, size(tr_labels) / batch_size
276274
277275
! pull a random mini-batch from the dataset
278276
call random_number(pos)
@@ -286,12 +284,10 @@ program example_mnist
286284
! train the network on the mini-batch
287285
call net % train(input, output, eta=3._rk)
288286
289-
end do mini_batches
287+
end do batches
290288
291-
if (this_image() == 1) then
292-
write(*, '(a,i2,a,f5.2,a)') 'Epoch ', n, ' done, Accuracy: ',&
293-
net % accuracy(te_images, label_digits(te_labels)) * 100, ' %'
294-
end if
289+
if (this_image() == 1) print '(a,i2,a,f5.2,a)', 'Epoch ', n, ' done, Accuracy: ', &
290+
net % accuracy(te_images, label_digits(te_labels)) * 100, ' %'
295291
296292
end do epochs
297293

src/lib/mod_network.f90

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ module mod_network
4040

4141
interface network_type
4242
module procedure :: net_constructor
43-
endinterface network_type
43+
end interface network_type
4444

4545
contains
4646

@@ -286,19 +286,19 @@ subroutine train_epochs(self, x, y, eta, num_epochs, batch_size)
286286
nsamples = size(y, dim=2)
287287
nbatch = nsamples / batch_size
288288

289-
epoch: do n = 1, num_epochs
290-
mini_batches: do i = 1, nbatch
289+
epochs: do n = 1, num_epochs
290+
batches: do i = 1, nbatch
291291

292-
!pull a random mini-batch from the dataset
293-
call random_number(pos)
294-
batch_start = int(pos * (nsamples - batch_size + 1))
295-
if (batch_start == 0) batch_start = 1
296-
batch_end = batch_start + batch_size - 1
292+
!pull a random mini-batch from the dataset
293+
call random_number(pos)
294+
batch_start = int(pos * (nsamples - batch_size + 1))
295+
if (batch_start == 0) batch_start = 1
296+
batch_end = batch_start + batch_size - 1
297297

298-
call self % train(x(:,batch_start:batch_end), y(:,batch_start:batch_end), eta)
298+
call self % train(x(:,batch_start:batch_end), y(:,batch_start:batch_end), eta)
299299

300-
enddo mini_batches
301-
enddo epoch
300+
end do batches
301+
end do epochs
302302

303303
end subroutine train_epochs
304304

src/tests/example_mnist.f90

Lines changed: 28 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -12,29 +12,44 @@ program example_mnist
1212

1313
real(rk), allocatable :: tr_images(:,:), tr_labels(:)
1414
real(rk), allocatable :: te_images(:,:), te_labels(:)
15+
real(rk), allocatable :: input(:,:), output(:,:)
1516

1617
type(network_type) :: net
1718

1819
integer(ik) :: i, n, num_epochs
19-
integer(ik) :: batch_size
20+
integer(ik) :: batch_size, batch_start, batch_end
21+
real(rk) :: pos
2022

2123
call load_mnist(tr_images, tr_labels, te_images, te_labels)
2224

23-
net = network_type([size(tr_images,dim=1), 10, size(label_digits(te_labels),dim=1)])
25+
net = network_type([784, 30, 10])
2426

25-
batch_size = 1000
27+
batch_size = 100
2628
num_epochs = 10
2729

28-
if (this_image() == 1) then
29-
write(*, '(a,f5.2,a)') 'Initial accuracy: ',&
30+
if (this_image() == 1) print '(a,f5.2,a)', 'Initial accuracy: ', &
31+
net % accuracy(te_images, label_digits(te_labels)) * 100, ' %'
32+
33+
epochs: do n = 1, num_epochs
34+
batches: do i = 1, size(tr_labels) / batch_size
35+
36+
! pull a random mini-batch from the dataset
37+
call random_number(pos)
38+
batch_start = int(pos * (size(tr_labels) - batch_size + 1))
39+
batch_end = batch_start + batch_size - 1
40+
41+
! prepare mini-batch
42+
input = tr_images(:,batch_start:batch_end)
43+
output = label_digits(tr_labels(batch_start:batch_end))
44+
45+
! train the network on the mini-batch
46+
call net % train(input, output, eta=3._rk)
47+
48+
end do batches
49+
50+
if (this_image() == 1) print '(a,i2,a,f5.2,a)', 'Epoch ', n, ' done, Accuracy: ', &
3051
net % accuracy(te_images, label_digits(te_labels)) * 100, ' %'
31-
end if
32-
33-
call net % train(tr_images, label_digits(tr_labels), 3._rk, num_epochs, batch_size)
34-
35-
if (this_image() == 1) then
36-
write(*, '(a,f5.2,a)') 'Epochs done, Accuracy: ',&
37-
net % accuracy(te_images, label_digits(te_labels)) * 100, ' %'
38-
endif
52+
53+
end do epochs
3954

4055
end program example_mnist

src/tests/example_mnist_epochs.f90

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
program example_mnist
2+
3+
! A training example with the MNIST dataset.
4+
! Uses stochastic gradient descent and mini-batch size of 100.
5+
! Can be run in serial or parallel mode without modifications.
6+
7+
use mod_kinds, only: ik, rk
8+
use mod_mnist, only: label_digits, load_mnist
9+
use mod_network, only: network_type
10+
11+
implicit none
12+
13+
real(rk), allocatable :: tr_images(:,:), tr_labels(:)
14+
real(rk), allocatable :: te_images(:,:), te_labels(:)
15+
16+
type(network_type) :: net
17+
18+
integer(ik) :: i, n, num_epochs
19+
integer(ik) :: batch_size
20+
21+
call load_mnist(tr_images, tr_labels, te_images, te_labels)
22+
23+
net = network_type([size(tr_images, dim=1), 10, size(label_digits(tr_labels), dim=1)])
24+
25+
batch_size = 100
26+
num_epochs = 10
27+
28+
if (this_image() == 1) print '(a,f5.2,a)', 'Initial accuracy: ', &
29+
net % accuracy(te_images, label_digits(te_labels)) * 100, ' %'
30+
31+
call net % train(tr_images, label_digits(tr_labels), 3._rk, num_epochs, batch_size)
32+
33+
if (this_image() == 1) print '(a,f5.2,a)', 'Epochs done, Accuracy: ', &
34+
net % accuracy(te_images, label_digits(te_labels)) * 100, ' %'
35+
36+
end program example_mnist

src/tests/example_montesinos_multi.f90

Lines changed: 0 additions & 136 deletions
This file was deleted.

0 commit comments

Comments
 (0)