Skip to content

Commit 908f76b

Browse files
committed
Update example programs
1 parent 4678135 commit 908f76b

File tree

8 files changed

+140
-155
lines changed

8 files changed

+140
-155
lines changed

example/example_mnist.f90

Lines changed: 0 additions & 55 deletions
This file was deleted.

example/example_mnist_epochs.f90

Lines changed: 0 additions & 36 deletions
This file was deleted.

example/example_save_and_load.f90

Lines changed: 0 additions & 32 deletions
This file was deleted.

example/example_simple.f90

Lines changed: 0 additions & 14 deletions
This file was deleted.

example/example_sine.f90

Lines changed: 0 additions & 18 deletions
This file was deleted.

example/mnist.f90

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
1+
program mnist
2+
use nf, only: dense, input, network
3+
use nf_datasets_mnist, only: label_digits, load_mnist
4+
use nf_optimizers, only: sgd
5+
6+
implicit none
7+
8+
type(network) :: net
9+
real, allocatable :: training_images(:,:), training_labels(:)
10+
real, allocatable :: validation_images(:,:), validation_labels(:)
11+
integer :: n, num_epochs
12+
13+
call load_mnist(training_images, training_labels, &
14+
validation_images, validation_labels)
15+
16+
print '("MNIST")'
17+
print '(60("="))'
18+
19+
net = network([ &
20+
input(784), &
21+
dense(30), &
22+
dense(10) &
23+
])
24+
num_epochs = 10
25+
26+
call net % print_info()
27+
28+
if (this_image() == 1) &
29+
print '(a,f5.2,a)', 'Initial accuracy: ', accuracy( &
30+
net, validation_images, label_digits(validation_labels)) * 100, ' %'
31+
32+
epochs: do n = 1, num_epochs
33+
34+
call net % train( &
35+
training_images, &
36+
label_digits(training_labels), &
37+
batch_size=100, &
38+
epochs=1, &
39+
optimizer=sgd(learning_rate=3.) &
40+
)
41+
42+
if (this_image() == 1) &
43+
print '(a,i2,a,f5.2,a)', 'Epoch ', n, ' done, Accuracy: ', accuracy( &
44+
net, validation_images, label_digits(validation_labels)) * 100, ' %'
45+
46+
end do epochs
47+
48+
contains
49+
50+
real function accuracy(net, x, y)
51+
type(network), intent(in out) :: net
52+
real, intent(in) :: x(:,:), y(:,:)
53+
integer :: i, good
54+
good = 0
55+
do i = 1, size(x, dim=2)
56+
if (all(maxloc(net % output(x(:,i))) == maxloc(y(:,i)))) then
57+
good = good + 1
58+
end if
59+
end do
60+
accuracy = real(good) / size(x, dim=2)
61+
end function accuracy
62+
63+
end program mnist

example/simple.f90

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
program simple
2+
use nf, only: dense, input, network
3+
implicit none
4+
type(network) :: net
5+
real, allocatable :: x(:), y(:)
6+
integer, parameter :: num_iterations = 500
7+
integer :: n
8+
9+
print '("Simple")'
10+
print '(60("="))'
11+
12+
net = network([ &
13+
input(3), &
14+
dense(5), &
15+
dense(2) &
16+
])
17+
18+
call net % print_info()
19+
20+
x = [0.2, 0.4, 0.6]
21+
y = [0.123456, 0.246802]
22+
23+
do n = 0, num_iterations
24+
25+
call net % forward(x)
26+
call net % backward(y)
27+
call net % update(1.)
28+
29+
if (mod(n, 50) == 0) &
30+
print '(i4,2(3x,f8.6))', n, net % output(x)
31+
32+
end do
33+
34+
end program simple

example/sine.f90

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
program sine
2+
use nf, only: dense, input, network
3+
implicit none
4+
type(network) :: net
5+
real :: x(1), y(1)
6+
real, parameter :: pi = 4 * atan(1.)
7+
integer, parameter :: num_iterations = 100000
8+
integer, parameter :: test_size = 30
9+
real :: xtest(test_size), ytest(test_size), ypred(test_size)
10+
integer :: i, n
11+
12+
print '("Sine training")'
13+
print '(60("="))'
14+
15+
net = network([ &
16+
input(1), &
17+
dense(5), &
18+
dense(1) &
19+
])
20+
21+
call net % print_info()
22+
23+
xtest = [((i - 1) * 2 * pi / test_size, i = 1, test_size)]
24+
ytest = (sin(xtest) + 1) / 2
25+
26+
do n = 0, num_iterations
27+
28+
call random_number(x)
29+
x = x * 2 * pi
30+
y = (sin(x) + 1) / 2
31+
32+
call net % forward(x)
33+
call net % backward(y)
34+
call net % update(1.)
35+
36+
if (mod(n, 10000) == 0) then
37+
ypred = [(net % output([xtest(i)]), i = 1, test_size)]
38+
print '(i0,1x,f9.6)', n, sum((ypred - ytest)**2) / size(ypred)
39+
end if
40+
41+
end do
42+
43+
end program sine

0 commit comments

Comments
 (0)