@@ -3,8 +3,6 @@ module mod_mnist
33 ! ! Procedures to work with MNIST dataset, usable with data format
44 ! ! as provided in this repo and not the original data format (idx).
55
6- use iso_fortran_env, only: real32 ! ! TODO make MNIST work with arbitrary precision
7- use mod_io, only: read_binary_file
86 use mod_kinds, only: ik, rk
97
108 implicit none
@@ -13,75 +11,33 @@ module mod_mnist
1311
1412 public :: label_digits, load_mnist, print_image
1513
16- contains
17-
18- pure function digits (x )
19- ! ! Returns an array of 10 reals, with zeros everywhere
20- ! ! and a one corresponding to the input number, for example:
21- ! ! digits(0) = [1., 0., 0., 0., 0., 0., 0., 0., 0., 0.]
22- ! ! digits(1) = [0., 1., 0., 0., 0., 0., 0., 0., 0., 0.]
23- ! ! digits(6) = [0., 0., 0., 0., 0., 0., 1., 0., 0., 0.]
24- real (rk), intent (in ) :: x
25- real (rk) :: digits (10 )
26- digits = 0
27- digits (int (x + 1 )) = 1
28- end function digits
29-
30- pure function label_digits (labels ) result(res)
31- ! ! Converts an array of MNIST labels into a form
32- ! ! that can be input to the network_type instance.
33- real (rk), intent (in ) :: labels(:)
34- real (rk) :: res(10 , size (labels))
35- integer (ik) :: i
36- do i = 1 , size (labels)
37- res(:,i) = digits (labels(i))
38- end do
39- end function label_digits
40-
41- subroutine load_mnist (tr_images , tr_labels , te_images ,&
42- te_labels , va_images , va_labels )
43- ! ! Loads the MNIST dataset into arrays.
44- real (rk), allocatable , intent (in out ) :: tr_images(:,:), tr_labels(:)
45- real (rk), allocatable , intent (in out ) :: te_images(:,:), te_labels(:)
46- real (rk), allocatable , intent (in out ), optional :: va_images(:,:), va_labels(:)
47- integer (ik), parameter :: dtype = 4 , image_size = 784
48- integer (ik), parameter :: tr_nimages = 50000
49- integer (ik), parameter :: te_nimages = 10000
50- integer (ik), parameter :: va_nimages = 10000
51-
52- call read_binary_file(' data/mnist/mnist_training_images.dat' ,&
53- dtype, image_size, tr_nimages, tr_images)
54- call read_binary_file(' data/mnist/mnist_training_labels.dat' ,&
55- dtype, tr_nimages, tr_labels)
56-
57- call read_binary_file(' data/mnist/mnist_testing_images.dat' ,&
58- dtype, image_size, te_nimages, te_images)
59- call read_binary_file(' data/mnist/mnist_testing_labels.dat' ,&
60- dtype, te_nimages, te_labels)
61-
62- if (present (va_images) .and. present (va_labels)) then
63- call read_binary_file(' data/mnist/mnist_validation_images.dat' ,&
64- dtype, image_size, va_nimages, va_images)
65- call read_binary_file(' data/mnist/mnist_validation_labels.dat' ,&
66- dtype, va_nimages, va_labels)
67- end if
68-
69- end subroutine load_mnist
70-
71- subroutine print_image (images , labels , n )
72- ! ! Prints a single image and label to screen.
73- real (rk), intent (in ) :: images(:,:), labels(:)
74- integer (ik), intent (in ) :: n
75- real (rk) :: image(28 , 28 )
76- character (len= 1 ) :: char_image(28 , 28 )
77- integer (ik) i, j
78- image = reshape (images(:,n), [28 , 28 ])
79- char_image = ' .'
80- where (image > 0 ) char_image = ' #'
81- print * , labels(n)
82- do j = 1 , 28
83- print * , char_image(:,j)
84- end do
85- end subroutine print_image
14+ interface
15+
16+ pure module function label_digits(labels) result(res)
17+ ! ! Converts an array of MNIST labels into a form
18+ ! ! that can be input to the network_type instance.
19+ implicit none
20+ real (rk), intent (in ) :: labels(:)
21+ real (rk) :: res(10 , size (labels))
22+ end function label_digits
23+
24+ module subroutine load_mnist (tr_images , tr_labels , te_images ,&
25+
26+ te_labels , va_images , va_labels )
27+ ! ! Loads the MNIST dataset into arrays.
28+ implicit none
29+ real (rk), allocatable , intent (in out ) :: tr_images(:,:), tr_labels(:)
30+ real (rk), allocatable , intent (in out ) :: te_images(:,:), te_labels(:)
31+ real (rk), allocatable , intent (in out ), optional :: va_images(:,:), va_labels(:)
32+ end subroutine load_mnist
33+
34+ module subroutine print_image (images , labels , n )
35+ ! ! Prints a single image and label to screen.
36+ implicit none
37+ real (rk), intent (in ) :: images(:,:), labels(:)
38+ integer (ik), intent (in ) :: n
39+ end subroutine print_image
40+
41+ end interface
8642
8743end module mod_mnist
0 commit comments