Skip to content

Commit bd643fd

Browse files
committed
feat(io): download archive if .dat files not found
1 parent 6b79d18 commit bd643fd

File tree

2 files changed

+54
-8
lines changed

2 files changed

+54
-8
lines changed

src/mod_io_submodule.f90

Lines changed: 48 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,16 +2,52 @@
22

33
implicit none
44

5+
integer, parameter :: message_len = 128
56
contains
67

8+
subroutine download_and_uncompress()
9+
character(len=*), parameter :: download_mechanism = 'curl -LO '
10+
character(len=*), parameter :: base_url='https://github.com/modern-fortran/neural-fortran/files/8498876/'
11+
character(len=*), parameter :: download_filename = 'mnist.tar.gz'
12+
character(len=*), parameter :: download_command = download_mechanism // base_url //download_filename
13+
character(len=*), parameter :: uncompress_file= 'tar xvzf ' // download_filename
14+
character(len=message_len) :: command_message
15+
character(len=:), allocatable :: error_message
16+
integer exit_status, command_status
17+
exit_status=0
18+
call execute_command_line(command=download_command, &
19+
wait=.true., exitstat=exit_status, cmdstat=command_status, cmdmsg=command_message)
20+
if (any([exit_status, command_status]/=0)) then
21+
error_message = 'command "' // download_command // '" failed'
22+
if (command_status/=0) error_message = error_message // " with message " // trim(command_message)
23+
error stop error_message
24+
end if
25+
call execute_command_line(command=uncompress_file , &
26+
wait=.true., exitstat=exit_status, cmdstat=command_status, cmdmsg=command_message)
27+
if (any([exit_status, command_status]/=0)) then
28+
error_message = 'command "' // uncompress_file // '" failed'
29+
if (command_status/=0) error_message = error_message // " with message " // trim(command_message)
30+
error stop error_message
31+
end if
32+
end subroutine
33+
734
module subroutine read_binary_file_1d(filename, dtype, nrec, array)
835
character(len=*), intent(in) :: filename
936
integer(ik), intent(in) :: dtype, nrec
1037
real(rk), allocatable, intent(in out) :: array(:)
1138
integer(ik) :: fileunit
12-
allocate(array(nrec))
39+
character(len=message_len) io_message, command_message
40+
integer io_status
41+
io_status=0
1342
open(newunit=fileunit, file=filename, access='direct',&
14-
action='read', recl=dtype * nrec, status='old')
43+
action='read', recl=dtype * nrec, status='old', iostat=io_status)
44+
if (io_status/=0) then
45+
call download_and_uncompress
46+
open(newunit=fileunit, file=filename, access='direct',&
47+
action='read', recl=dtype * nrec, status='old', iostat=io_status, iomsg=io_message)
48+
if (io_status/=0) error stop trim(io_message)
49+
end if
50+
allocate(array(nrec))
1551
read(fileunit, rec=1) array
1652
close(fileunit)
1753
end subroutine read_binary_file_1d
@@ -21,6 +57,16 @@ module subroutine read_binary_file_2d(filename, dtype, dsize, nrec, array)
2157
integer(ik), intent(in) :: dtype, dsize, nrec
2258
real(rk), allocatable, intent(in out) :: array(:,:)
2359
integer(ik) :: fileunit, i
60+
character(len=message_len) io_message, command_message
61+
integer io_status
62+
open(newunit=fileunit, file=filename, access='direct',&
63+
action='read', recl=dtype * nrec, status='old', iostat=io_status)
64+
if (io_status/=0) then
65+
call download_and_uncompress
66+
open(newunit=fileunit, file=filename, access='direct',&
67+
action='read', recl=dtype * nrec, status='old', iostat=io_status, iomsg=io_message)
68+
if (io_status/=0) error stop trim(io_message)
69+
end if
2470
allocate(array(dsize, nrec))
2571
open(newunit=fileunit, file=filename, access='direct',&
2672
action='read', recl=dtype * dsize, status='old')

src/mod_mnist_submodule.f90

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -43,20 +43,20 @@ module subroutine load_mnist(tr_images, tr_labels, te_images,&
4343
integer(ik), parameter :: te_nimages = 10000
4444
integer(ik), parameter :: va_nimages = 10000
4545

46-
call read_binary_file('data/mnist/mnist_training_images.dat',&
46+
call read_binary_file('mnist_training_images.dat',&
4747
dtype, image_size, tr_nimages, tr_images)
48-
call read_binary_file('data/mnist/mnist_training_labels.dat',&
48+
call read_binary_file('mnist_training_labels.dat',&
4949
dtype, tr_nimages, tr_labels)
5050

51-
call read_binary_file('data/mnist/mnist_testing_images.dat',&
51+
call read_binary_file('mnist_testing_images.dat',&
5252
dtype, image_size, te_nimages, te_images)
53-
call read_binary_file('data/mnist/mnist_testing_labels.dat',&
53+
call read_binary_file('mnist_testing_labels.dat',&
5454
dtype, te_nimages, te_labels)
5555

5656
if (present(va_images) .and. present(va_labels)) then
57-
call read_binary_file('data/mnist/mnist_validation_images.dat',&
57+
call read_binary_file('mnist_validation_images.dat',&
5858
dtype, image_size, va_nimages, va_images)
59-
call read_binary_file('data/mnist/mnist_validation_labels.dat',&
59+
call read_binary_file('mnist_validation_labels.dat',&
6060
dtype, va_nimages, va_labels)
6161
end if
6262

0 commit comments

Comments
 (0)