Skip to content

Commit 1648f92

Browse files
committed
Load weights and biases from file and test the accuracy of the pre-trained model
1 parent 7d87dd3 commit 1648f92

File tree

2 files changed

+79
-2
lines changed

2 files changed

+79
-2
lines changed

src/nf/nf_network_submodule.f90

Lines changed: 41 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
use nf_flatten_layer, only: flatten_layer
55
use nf_input1d_layer, only: input1d_layer
66
use nf_input3d_layer, only: input3d_layer
7+
use nf_io_hdf5, only: get_hdf5_dataset
78
use nf_keras, only: get_keras_h5_layers, keras_layer
89
use nf_layer, only: layer
910
use nf_layer_constructors, only: dense, input
@@ -55,6 +56,8 @@ module function network_from_keras(filename) result(res)
5556
type(network) :: res
5657
type(keras_layer), allocatable :: keras_layers(:)
5758
type(layer), allocatable :: layers(:)
59+
character(:), allocatable :: layer_name
60+
character(:), allocatable :: object_name
5861
integer :: n
5962

6063
keras_layers = get_keras_h5_layers(filename)
@@ -89,7 +92,44 @@ module function network_from_keras(filename) result(res)
8992

9093
res = network(layers)
9194

92-
!TODO read weights and biases from Keras file and set here
95+
! Loop over layers and read weights and biases from the Keras h5 file
96+
! for each; currently only dense layers are implemented.
97+
do n = 2, size(res % layers)
98+
99+
layer_name = keras_layers(n) % name
100+
101+
if (keras_layers(n) % class == 'Dense') then
102+
select type(this_layer => res % layers(n) % p)
103+
104+
type is(dense_layer)
105+
106+
! Read biases from file
107+
object_name = '/model_weights/' // layer_name // '/' &
108+
// layer_name // '/bias:0'
109+
call get_hdf5_dataset(filename, object_name, this_layer % biases)
110+
111+
! Read weights from file
112+
object_name = '/model_weights/' // layer_name // '/' &
113+
// layer_name // '/kernel:0'
114+
call get_hdf5_dataset(filename, object_name, this_layer % weights)
115+
116+
! TODO Multidimensional arrays are stored in HDF5 in C-order.
117+
! TODO Here we transpose the array to get to the Fortran order.
118+
! TODO There may be a way to do this without re-allocating.
119+
! TODO It probably doesn't matter much since we do this once.
120+
! TODO Figure it out later.
121+
this_layer % weights = transpose(this_layer % weights)
122+
123+
class default
124+
error stop 'Internal error in network_from_keras(); ' &
125+
// 'mismatch in layer types between the Keras and ' &
126+
// 'neural-fortran model layers.'
127+
128+
end select
129+
130+
end if
131+
132+
end do
93133

94134
end function network_from_keras
95135

test/test_dense_network_from_keras.f90

Lines changed: 38 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@ program test_dense_network_from_keras
88

99
type(network) :: net
1010
character(*), parameter :: test_data_path = 'keras_dense_mnist.h5'
11-
1211
logical :: file_exists
1312
logical :: ok = .true.
1413

@@ -52,6 +51,29 @@ program test_dense_network_from_keras
5251
ok = .false.
5352
end if
5453

54+
block
55+
56+
use nf, only: load_mnist, label_digits
57+
58+
real, allocatable :: training_images(:,:), training_labels(:)
59+
real, allocatable :: validation_images(:,:), validation_labels(:)
60+
real, allocatable :: testing_images(:,:), testing_labels(:)
61+
real :: acc
62+
63+
call load_mnist(training_images, training_labels, &
64+
validation_images, validation_labels, &
65+
testing_images, testing_labels)
66+
67+
acc = accuracy(net, testing_images, label_digits(testing_labels))
68+
69+
if (acc < 0.94) then
70+
write(stderr, '(a)') &
71+
'Pre-trained network accuracy should be > 0.94.. failed'
72+
ok = .false.
73+
end if
74+
75+
end block
76+
5577
if (ok) then
5678
print '(a)', 'test_dense_network_from_keras: All tests passed.'
5779
else
@@ -60,4 +82,19 @@ program test_dense_network_from_keras
6082
stop 1
6183
end if
6284

85+
contains
86+
87+
real function accuracy(net, x, y)
88+
type(network), intent(in out) :: net
89+
real, intent(in) :: x(:,:), y(:,:)
90+
integer :: i, good
91+
good = 0
92+
do i = 1, size(x, dim=2)
93+
if (all(maxloc(net % output(x(:,i))) == maxloc(y(:,i)))) then
94+
good = good + 1
95+
end if
96+
end do
97+
accuracy = real(good) / size(x, dim=2)
98+
end function accuracy
99+
63100
end program test_dense_network_from_keras

0 commit comments

Comments
 (0)