Skip to content

Commit ef9efbc

Browse files
committed
make loading of MNIST validation set optional
1 parent 3a569b0 commit ef9efbc

File tree

3 files changed

+10
-11
lines changed

3 files changed

+10
-11
lines changed

README.md

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -238,7 +238,6 @@ program example_mnist
238238
239239
real(rk), allocatable :: tr_images(:,:), tr_labels(:)
240240
real(rk), allocatable :: te_images(:,:), te_labels(:)
241-
real(rk), allocatable :: va_images(:,:), va_labels(:)
242241
real(rk), allocatable :: input(:,:), output(:,:)
243242
244243
type(network_type) :: net
@@ -247,8 +246,7 @@ program example_mnist
247246
integer(ik) :: batch_size, batch_start, batch_end
248247
real(rk) :: pos
249248
250-
call load_mnist(tr_images, tr_labels, te_images,&
251-
te_labels, va_images, va_labels)
249+
call load_mnist(tr_images, tr_labels, te_images, te_labels)
252250
253251
net = network_type([784, 30, 10])
254252

src/lib/mod_mnist.f90

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ subroutine load_mnist(tr_images, tr_labels, te_images,&
4343
! Loads the MNIST dataset into arrays.
4444
real(rk), allocatable, intent(in out) :: tr_images(:,:), tr_labels(:)
4545
real(rk), allocatable, intent(in out) :: te_images(:,:), te_labels(:)
46-
real(rk), allocatable, intent(in out) :: va_images(:,:), va_labels(:)
46+
real(rk), allocatable, intent(in out), optional :: va_images(:,:), va_labels(:)
4747
integer(ik), parameter :: dtype = 4, image_size = 784
4848
integer(ik), parameter :: tr_nimages = 50000
4949
integer(ik), parameter :: te_nimages = 10000
@@ -59,10 +59,12 @@ subroutine load_mnist(tr_images, tr_labels, te_images,&
5959
call read_binary_file('../data/mnist/mnist_testing_labels.dat',&
6060
dtype, te_nimages, te_labels)
6161

62-
call read_binary_file('../data/mnist/mnist_validation_images.dat',&
63-
dtype, image_size, va_nimages, va_images)
64-
call read_binary_file('../data/mnist/mnist_validation_labels.dat',&
65-
dtype, va_nimages, va_labels)
62+
if (present(va_images) .and. present(va_labels)) then
63+
call read_binary_file('../data/mnist/mnist_validation_images.dat',&
64+
dtype, image_size, va_nimages, va_images)
65+
call read_binary_file('../data/mnist/mnist_validation_labels.dat',&
66+
dtype, va_nimages, va_labels)
67+
end if
6668

6769
end subroutine load_mnist
6870

src/tests/example_mnist.f90

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ program example_mnist
1212

1313
real(rk), allocatable :: tr_images(:,:), tr_labels(:)
1414
real(rk), allocatable :: te_images(:,:), te_labels(:)
15-
real(rk), allocatable :: va_images(:,:), va_labels(:)
15+
!real(rk), allocatable :: va_images(:,:), va_labels(:)
1616
real(rk), allocatable :: input(:,:), output(:,:)
1717

1818
type(network_type) :: net
@@ -21,8 +21,7 @@ program example_mnist
2121
integer(ik) :: batch_size, batch_start, batch_end
2222
real(rk) :: pos
2323

24-
call load_mnist(tr_images, tr_labels, te_images,&
25-
te_labels, va_images, va_labels)
24+
call load_mnist(tr_images, tr_labels, te_images, te_labels)
2625

2726
net = network_type([784, 10, 10])
2827

0 commit comments

Comments
 (0)