Skip to content

Commit 69237b3

Browse files
committed
Add example about loading a Keras model from file
1 parent 1648f92 commit 69237b3

File tree

3 files changed

+55
-2
lines changed

3 files changed

+55
-2
lines changed

CMakeLists.txt

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -106,13 +106,13 @@ string(REGEX REPLACE "^ | $" "" LIBS "${LIBS}")
106106

107107
# tests
108108
enable_testing()
109-
foreach(execid input1d_layer input3d_layer dense_layer conv2d_layer maxpool2d_layer flatten_layer dense_network conv2d_network io_hdf5 keras_read_model)
109+
foreach(execid input1d_layer input3d_layer dense_layer conv2d_layer maxpool2d_layer flatten_layer dense_network dense_network_from_keras conv2d_network io_hdf5 keras_read_model)
110110
add_executable(test_${execid} test/test_${execid}.f90)
111111
target_link_libraries(test_${execid} neural ${LIBS})
112112
add_test(test_${execid} bin/test_${execid})
113113
endforeach()
114114

115-
foreach(execid cnn mnist simple sine)
115+
foreach(execid cnn mnist mnist_from_keras simple sine)
116116
add_executable(${execid} example/${execid}.f90)
117117
target_link_libraries(${execid} neural ${LIBS})
118118
endforeach()

README.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -180,6 +180,8 @@ examples, in increasing level of complexity:
180180
dataset
181181
4. [cnn](example/cnn.f90): Creating and running forward a simple CNN using
182182
`input`, `conv2d`, `maxpool2d`, `flatten`, and `dense` layers.
183+
5. [mnist_from_keras](example/mnist_from_keras.f90): Creating a pre-trained
184+
model from a Keras HDF5 file.
183185

184186
The examples also show you the extent of the public API that's meant to be
185187
used in applications, i.e. anything from the `nf` module.

example/mnist_from_keras.f90

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
program mnist_from_keras
2+
3+
! This example demonstrates loading a pre-trained MNIST model from Keras
4+
! from an HDF5 file and running an inferrence on the testing dataset.
5+
6+
use nf, only: network, label_digits, load_mnist
7+
use nf_datasets, only: download_and_unpack, keras_model_dense_mnist_url
8+
9+
implicit none
10+
11+
type(network) :: net
12+
real, allocatable :: training_images(:,:), training_labels(:)
13+
real, allocatable :: validation_images(:,:), validation_labels(:)
14+
real, allocatable :: testing_images(:,:), testing_labels(:)
15+
character(*), parameter :: test_data_path = 'keras_dense_mnist.h5'
16+
logical :: file_exists
17+
18+
inquire(file=test_data_path, exist=file_exists)
19+
if (.not. file_exists) call download_and_unpack(keras_model_dense_mnist_url)
20+
21+
call load_mnist(training_images, training_labels, &
22+
validation_images, validation_labels, &
23+
testing_images, testing_labels)
24+
25+
print '("Loading a pre-trained MNIST model from Keras")'
26+
print '(60("="))'
27+
28+
net = network(test_data_path)
29+
30+
call net % print_info()
31+
32+
if (this_image() == 1) &
33+
print '(a,f5.2,a)', 'Accuracy: ', accuracy( &
34+
net, testing_images, label_digits(testing_labels)) * 100, ' %'
35+
36+
contains
37+
38+
real function accuracy(net, x, y)
39+
type(network), intent(in out) :: net
40+
real, intent(in) :: x(:,:), y(:,:)
41+
integer :: i, good
42+
good = 0
43+
do i = 1, size(x, dim=2)
44+
if (all(maxloc(net % output(x(:,i))) == maxloc(y(:,i)))) then
45+
good = good + 1
46+
end if
47+
end do
48+
accuracy = real(good) / size(x, dim=2)
49+
end function accuracy
50+
51+
end program mnist_from_keras

0 commit comments

Comments
 (0)