|
| 1 | +program mnist_from_keras |
| 2 | + |
| 3 | + ! This example demonstrates loading a pre-trained MNIST model from Keras |
| 4 | + ! from an HDF5 file and running an inferrence on the testing dataset. |
| 5 | + |
| 6 | + use nf, only: network, label_digits, load_mnist |
| 7 | + use nf_datasets, only: download_and_unpack, keras_model_dense_mnist_url |
| 8 | + |
| 9 | + implicit none |
| 10 | + |
| 11 | + type(network) :: net |
| 12 | + real, allocatable :: training_images(:,:), training_labels(:) |
| 13 | + real, allocatable :: validation_images(:,:), validation_labels(:) |
| 14 | + real, allocatable :: testing_images(:,:), testing_labels(:) |
| 15 | + character(*), parameter :: test_data_path = 'keras_dense_mnist.h5' |
| 16 | + logical :: file_exists |
| 17 | + |
| 18 | + inquire(file=test_data_path, exist=file_exists) |
| 19 | + if (.not. file_exists) call download_and_unpack(keras_model_dense_mnist_url) |
| 20 | + |
| 21 | + call load_mnist(training_images, training_labels, & |
| 22 | + validation_images, validation_labels, & |
| 23 | + testing_images, testing_labels) |
| 24 | + |
| 25 | + print '("Loading a pre-trained MNIST model from Keras")' |
| 26 | + print '(60("="))' |
| 27 | + |
| 28 | + net = network(test_data_path) |
| 29 | + |
| 30 | + call net % print_info() |
| 31 | + |
| 32 | + if (this_image() == 1) & |
| 33 | + print '(a,f5.2,a)', 'Accuracy: ', accuracy( & |
| 34 | + net, testing_images, label_digits(testing_labels)) * 100, ' %' |
| 35 | + |
| 36 | +contains |
| 37 | + |
| 38 | + real function accuracy(net, x, y) |
| 39 | + type(network), intent(in out) :: net |
| 40 | + real, intent(in) :: x(:,:), y(:,:) |
| 41 | + integer :: i, good |
| 42 | + good = 0 |
| 43 | + do i = 1, size(x, dim=2) |
| 44 | + if (all(maxloc(net % output(x(:,i))) == maxloc(y(:,i)))) then |
| 45 | + good = good + 1 |
| 46 | + end if |
| 47 | + end do |
| 48 | + accuracy = real(good) / size(x, dim=2) |
| 49 | + end function accuracy |
| 50 | + |
| 51 | +end program mnist_from_keras |
0 commit comments