@@ -23,18 +23,21 @@ program test_cnn_from_keras
2323 real , allocatable :: training_images(:,:), training_labels(:)
2424 real , allocatable :: validation_images(:,:), validation_labels(:)
2525 real , allocatable :: testing_images(:,:), testing_labels(:)
26+ real , allocatable :: input_reshaped(:,:,:,:)
2627 real :: acc
2728
2829 call load_mnist(training_images, training_labels, &
2930 validation_images, validation_labels, &
3031 testing_images, testing_labels)
3132
32- acc = accuracy(net, reshape (testing_images, shape = [ 1 , 28 , 28 , 10000 ]), label_digits(testing_labels))
33- print * , acc
33+ ! Use only the first 1000 images to make the test short
34+ input_reshaped = reshape (testing_images(:,: 1000 ), shape = [ 1 , 28 , 28 , 1000 ])
3435
35- if (acc < 0.94 ) then
36+ acc = accuracy(net, input_reshaped, label_digits(testing_labels(:1000 )))
37+
38+ if (acc < 0.97 ) then
3639 write (stderr, ' (a)' ) &
37- ' Pre-trained network accuracy should be > 0.94 .. failed'
40+ ' Pre-trained network accuracy should be > 0.97 .. failed'
3841 ok = .false.
3942 end if
4043
@@ -55,12 +58,12 @@ real function accuracy(net, x, y)
5558 real , intent (in ) :: x(:,:,:,:), y(:,:)
5659 integer :: i, good
5760 good = 0
58- do i = 1 , size (x, dim= 2 )
61+ do i = 1 , size (x, dim= 4 )
5962 if (all (maxloc (net % output(x(:,:,:,i))) == maxloc (y(:,i)))) then
6063 good = good + 1
6164 end if
6265 end do
63- accuracy = real (good) / size (x, dim= 2 )
66+ accuracy = real (good) / size (x, dim= 4 )
6467 end function accuracy
6568
6669end program test_cnn_from_keras
0 commit comments