55#:set KINDS_TYPES = REAL_KINDS_TYPES + INT_KINDS_TYPES + CMPLX_KINDS_TYPES
66
77!> Implementation of loading npy files into multidimensional arrays
8- submodule (stdlib_io_np) stdlib_io_npy_load
9- use stdlib_error, only : error_stop
10- use stdlib_strings, only : to_string, starts_with
8+ submodule(stdlib_io_np) stdlib_io_np_load
9+ use stdlib_error, only: error_stop
10+ use stdlib_strings, only: to_string, starts_with
11+ use stdlib_string_type, only: string_type
12+ use stdlib_io_zip, only: unzip, zip_prefix, zip_suffix, t_unzipped_bundle, t_unzipped_file
13+ use stdlib_array
1114 implicit none
1215
1316contains
@@ -33,28 +36,12 @@ contains
3336
3437 open(newunit=io, file=filename, form="unformatted", access="stream", iostat=stat)
3538 catch: block
36- character(len=:), allocatable :: this_type
3739 integer, allocatable :: vshape(:)
3840
39- call get_descriptor (io, filename, this_type , vshape, stat, msg)
41+ call verify_npy_file (io, filename, vtype , vshape, rank , stat, msg)
4042 if (stat /= 0) exit catch
4143
42- if (this_type /= vtype) then
43- stat = 1
44- msg = "File '"//filename//"' contains data of type '"//this_type//"', "//&
45- & "but expected '"//vtype//"'"
46- exit catch
47- end if
48-
49- if (size(vshape) /= rank) then
50- stat = 1
51- msg = "File '"//filename//"' contains data of rank "//&
52- & to_string(size(vshape))//", but expected "//&
53- & to_string(rank)
54- exit catch
55- end if
56-
57- call allocator(array, vshape, stat)
44+ call allocate_array(array, vshape, stat)
5845 if (stat /= 0) then
5946 msg = "Failed to allocate array of type '"//vtype//"' "//&
6047 & "with total size of "//to_string(product(vshape))
@@ -76,30 +63,210 @@ contains
7663 end if
7764
7865 if (present(iomsg).and.allocated(msg)) call move_alloc(msg, iomsg)
79- contains
66+ end
67+ #:endfor
68+ #:endfor
8069
81- !> Wrapped intrinsic allocate to create an allocation from a shape array
82- subroutine allocator(array, vshape, stat)
83- !> Instance of the array to be allocated
84- ${t1}$, allocatable, intent(out) :: array${ranksuffix(rank)}$
85- !> Dimensions to allocate for
86- integer, intent(in) :: vshape(:)
87- !> Status of allocate
70+ !> Verify header, type and rank of the npy file.
71+ subroutine verify_npy_file(io, filename, vtype, vshape, rank, stat, msg)
72+ !> Access unit to the npy file.
73+ integer, intent(in) :: io
74+ !> Name of the npy file to load from.
75+ character(len=*), intent(in) :: filename
76+ !> Type of the data stored, retrieved from field `descr`.
77+ character(len=*), intent(in) :: vtype
78+ !> Shape of the stored data, retrieved from field `shape`.
79+ integer, allocatable, intent(out) :: vshape(:)
80+ !> Expected rank of the data.
81+ integer, intent(in) :: rank
82+ !> Status of operation.
8883 integer, intent(out) :: stat
84+ !> Associated error message in case of non-zero status.
85+ character(len=:), allocatable, intent(out) :: msg
86+
87+ character(len=:), allocatable :: this_type
88+
89+ call get_descriptor(io, filename, this_type, vshape, stat, msg)
90+ if (stat /= 0) return
91+
92+ if (this_type /= vtype) then
93+ stat = 1
94+ msg = "File '"//filename//"' contains data of type '"//this_type//"', "//&
95+ & "but expected '"//vtype//"'"
96+ return
97+ end if
98+
99+ if (size(vshape) /= rank) then
100+ stat = 1
101+ msg = "File '"//filename//"' contains data of rank "//&
102+ & to_string(size(vshape))//", but expected "//&
103+ & to_string(rank)
104+ return
105+ end if
106+ end
107+
108+ #:for k1, t1 in KINDS_TYPES
109+ #:for rank in RANKS
110+ module subroutine allocate_array_${t1[0]}$${k1}$_${rank}$(array, vshape, stat)
111+ ${t1}$, allocatable, intent(out) :: array${ranksuffix(rank)}$
112+ integer, intent(in) :: vshape(:)
113+ integer, intent(out) :: stat
114+
115+ allocate(array( &
116+ #:for i in range(rank-1)
117+ & vshape(${i+1}$), &
118+ #:endfor
119+ & vshape(${rank}$)), &
120+ & stat=stat)
121+ end
122+ #:endfor
123+ #:endfor
124+
125+ !> Version: experimental
126+ !>
127+ !> Load multidimensional arrays from a compressed or uncompressed npz file.
128+ !> ([Specification](../page/specs/stdlib_io.html#load_npz))
129+ module subroutine load_npz_to_bundle(filename, array_bundle, iostat, iomsg)
130+ character(len=*), intent(in) :: filename
131+ type(t_array_bundle), intent(out) :: array_bundle
132+ integer, intent(out), optional :: iostat
133+ character(len=:), allocatable, intent(out), optional :: iomsg
89134
90- allocate(array( &
91- #:for i in range(rank-1)
92- & vshape(${i+1}$), &
135+ type(t_unzipped_bundle) :: unzipped_bundle
136+ integer :: stat
137+ character(len=:), allocatable :: msg
138+
139+ call unzip(filename, unzipped_bundle, stat, msg)
140+ if (stat == 0) then
141+ call load_raw_to_bundle(unzipped_bundle, array_bundle, stat, msg)
142+ else
143+ call identify_problem(filename, stat, msg)
144+ end if
145+
146+ if (present(iostat)) then
147+ iostat = stat
148+ else if (stat /= 0) then
149+ if (allocated(msg)) then
150+ call error_stop("Failed to read arrays from file '"//filename//"'"//nl//msg)
151+ else
152+ call error_stop("Failed to read arrays from file '"//filename//"'")
153+ end if
154+ end if
155+
156+ if (present(iomsg) .and. allocated(msg)) call move_alloc(msg, iomsg)
157+ end
158+
159+ module subroutine load_raw_to_bundle(unzipped_bundle, array_bundle, stat, msg)
160+ type(t_unzipped_bundle), intent(in) :: unzipped_bundle
161+ type(t_array_bundle), intent(out) :: array_bundle
162+ integer, intent(out) :: stat
163+ character(len=:), allocatable, intent(out) :: msg
164+
165+ integer :: i, io
166+
167+ allocate (array_bundle%files(size(unzipped_bundle%files)))
168+ do i = 1, size(unzipped_bundle%files)
169+ array_bundle%files(i)%name = unzipped_bundle%files(i)%name
170+ open (newunit=io, status='scratch', form='unformatted', access='stream', iostat=stat)
171+ if (stat /= 0) return
172+ write (io) unzipped_bundle%files(i)%data
173+ call load_string_to_array(io, unzipped_bundle%files(i), array_bundle%files(i), stat, msg)
174+ close (io, status='delete', iostat=stat)
175+ if (stat /= 0) return
176+ end do
177+ end
178+
179+ module subroutine load_string_to_array(io, unzipped_file, array, stat, msg)
180+ integer, intent(in) :: io
181+ type(t_unzipped_file), intent(in) :: unzipped_file
182+ class(t_array), intent(inout) :: array
183+ integer, intent(out) :: stat
184+ character(len=:), allocatable, intent(out) :: msg
185+
186+ #:for k1, t1 in KINDS_TYPES
187+ #:for rank in RANKS
188+ ${t1}$, allocatable :: array_${t1[0]}$${k1}$_${rank}$${ranksuffix(rank)}$
93189 #:endfor
94- & vshape(${rank}$)), &
95- & stat=stat)
190+ #:endfor
96191
97- end subroutine allocator
192+ integer, allocatable :: vshape(:)
98193
99- end subroutine load_npy_${t1[0]}$${k1}$_${rank}$
100- #:endfor
101- #:endfor
194+ select type (arr => array)
195+ #:for k1, t1 in KINDS_TYPES
196+ #:for rank in RANKS
197+ type is (t_array_${t1[0]}$${k1}$_${rank}$)
198+ call verify_npy_file(io, unzipped_file%name, type_${t1[0]}$${k1}$, vshape, ${rank}$, stat, msg)
199+ if (stat /= 0) return
200+ call allocate_array(array_${t1[0]}$${k1}$_${rank}$, vshape, stat)
201+ if (stat /= 0) then
202+ msg = "Failed to allocate array of type '"//type_${t1[0]}$${k1}$//"' "//&
203+ & "with total size of "//to_string(product(vshape))
204+ return
205+ end if
206+ read (io, iostat=stat) array_${t1[0]}$${k1}$_${rank}$${ranksuffix(rank)}$
207+ arr%values = array_${t1[0]}$${k1}$_${rank}$${ranksuffix(rank)}$
208+ #:endfor
209+ #:endfor
210+ class default
211+ stat = 1; msg = 'Unsupported array type.'; return
212+ end select
213+ end
214+
215+ !> Open file and try to identify the problem.
216+ module subroutine identify_problem(filename, stat, msg)
217+ character(len=*), intent(in) :: filename
218+ integer, intent(inout) :: stat
219+ character(len=:), allocatable, intent(inout) :: msg
102220
221+ logical :: exists
222+ integer :: io_unit, prev_stat
223+ character(len=:), allocatable :: prev_msg
224+
225+ ! Keep track of the previous status and message in case no reason can be found.
226+ prev_stat = stat
227+ if (allocated(msg)) call move_alloc(msg, prev_msg)
228+
229+ inquire (file=filename, exist=exists)
230+ if (.not. exists) then
231+ stat = 1; msg = 'File does not exist: '//filename//'.'; return
232+ end if
233+ open (newunit=io_unit, file=filename, form='unformatted', access='stream', &
234+ & status='old', action='read', iostat=stat, iomsg=msg)
235+ if (stat /= 0) return
236+
237+ call verify_header(io_unit, stat, msg)
238+ if (stat /= 0) return
239+
240+ ! Restore previous status and message if no reason could be found.
241+ stat = prev_stat; msg = 'Failed to unzip file: '//filename//nl//prev_msg
242+ end
243+
244+ module subroutine verify_header(io_unit, stat, msg)
245+ integer, intent(in) :: io_unit
246+ integer, intent(out) :: stat
247+ character(len=:), allocatable, intent(out) :: msg
248+
249+ integer :: file_size
250+ character(len=len(zip_prefix)) :: header
251+
252+ inquire (io_unit, size=file_size)
253+ if (file_size < len(zip_suffix)) then
254+ stat = 1; msg = 'File is too small to be an npz file.'; return
255+ end if
256+
257+ read (io_unit, iostat=stat) header
258+ if (stat /= 0) then
259+ msg = 'Failed to read header from file'; return
260+ end if
261+
262+ if (header == zip_suffix) then
263+ stat = 1; msg = 'Empty npz file.'; return
264+ end if
265+
266+ if (header /= zip_prefix) then
267+ stat = 1; msg = 'Not an npz file.'; return
268+ end if
269+ end
103270
104271 !> Read the npy header from a binary file and retrieve the descriptor string.
105272 subroutine get_descriptor(io, filename, vtype, vshape, stat, msg)
@@ -168,7 +335,7 @@ contains
168335 if (.not.fortran_order) then
169336 vshape = [(vshape(i), i = size(vshape), 1, -1)]
170337 end if
171- end subroutine get_descriptor
338+ end
172339
173340
174341 !> Parse the first eight bytes of the npy header to verify the data
@@ -214,7 +381,7 @@ contains
214381 & "'"//to_string(major)//"."//to_string(minor)//"'"
215382 return
216383 end if
217- end subroutine parse_header
384+ end
218385
219386 !> Parse the descriptor in the npy header. This routine implements a minimal
220387 !> non-recursive parser for serialized Python dictionaries.
@@ -367,7 +534,7 @@ contains
367534 & "1 | " // input // nl // &
368535 & " |" // repeat(" ", first) // repeat("^", last - first + 1) // nl // &
369536 & " |"
370- end function make_message
537+ end
371538
372539 !> Parse a tuple of integers into an array of integers
373540 subroutine parse_tuple(input, pos, tuple, stat, msg)
@@ -427,7 +594,7 @@ contains
427594 return
428595 end select
429596 end do
430- end subroutine parse_tuple
597+ end
431598
432599 !> Get the next allowed token
433600 subroutine next_token(input, pos, token, allowed_token, stat, msg)
@@ -459,7 +626,7 @@ contains
459626 exit
460627 end if
461628 end do
462- end subroutine next_token
629+ end
463630
464631 !> Tokenize input string
465632 subroutine get_token(input, pos, token)
@@ -531,8 +698,8 @@ contains
531698 token = token_type(pos, pos, invalid)
532699 end select
533700
534- end subroutine get_token
701+ end
535702
536- end subroutine parse_descriptor
703+ end
537704
538- end submodule stdlib_io_npy_load
705+ end
0 commit comments