@@ -36,11 +36,27 @@ contains
3636
3737 open(newunit=io, file=filename, form="unformatted", access="stream", iostat=stat)
3838 catch: block
39+ character(len=:), allocatable :: this_type
3940 integer, allocatable :: vshape(:)
4041
41- call verify_npy_file (io, filename, vtype , vshape, rank , stat, msg)
42+ call get_descriptor (io, filename, this_type , vshape, stat, msg)
4243 if (stat /= 0) exit catch
4344
45+ if (this_type /= vtype) then
46+ stat = 1
47+ msg = "File '"//filename//"' contains data of type '"//this_type//"', "//&
48+ & "but expected '"//vtype//"'"
49+ exit catch
50+ end if
51+
52+ if (size(vshape) /= rank) then
53+ stat = 1
54+ msg = "File '"//filename//"' contains data of rank "//&
55+ & to_string(size(vshape))//", but expected "//&
56+ & to_string(rank)
57+ exit catch
58+ end if
59+
4460 call allocate_array(array, vshape, stat)
4561 if (stat /= 0) then
4662 msg = "Failed to allocate array of type '"//vtype//"' "//&
@@ -67,44 +83,6 @@ contains
6783 #:endfor
6884#:endfor
6985
70- !> Verify header, type and rank of the npy file.
71- subroutine verify_npy_file(io, filename, vtype, vshape, rank, stat, msg)
72- !> Access unit to the npy file.
73- integer, intent(in) :: io
74- !> Name of the npy file to load from.
75- character(len=*), intent(in) :: filename
76- !> Type of the data stored, retrieved from field `descr`.
77- character(len=*), intent(in) :: vtype
78- !> Shape of the stored data, retrieved from field `shape`.
79- integer, allocatable, intent(out) :: vshape(:)
80- !> Expected rank of the data.
81- integer, intent(in) :: rank
82- !> Status of operation.
83- integer, intent(out) :: stat
84- !> Associated error message in case of non-zero status.
85- character(len=:), allocatable, intent(out) :: msg
86-
87- character(len=:), allocatable :: this_type
88-
89- call get_descriptor(io, filename, this_type, vshape, stat, msg)
90- if (stat /= 0) return
91-
92- if (this_type /= vtype) then
93- stat = 1
94- msg = "File '"//filename//"' contains data of type '"//this_type//"', "//&
95- & "but expected '"//vtype//"'"
96- return
97- end if
98-
99- if (size(vshape) /= rank) then
100- stat = 1
101- msg = "File '"//filename//"' contains data of rank "//&
102- & to_string(size(vshape))//", but expected "//&
103- & to_string(rank)
104- return
105- end if
106- end
107-
10886 #:for k1, t1 in KINDS_TYPES
10987 #:for rank in RANKS
11088 module subroutine allocate_array_${t1[0]}$${k1}$_${rank}$(array, vshape, stat)
@@ -126,9 +104,9 @@ contains
126104 !>
127105 !> Load multidimensional arrays from a compressed or uncompressed npz file.
128106 !> ([Specification](../page/specs/stdlib_io.html#load_npz))
129- module subroutine load_npz_to_bundle (filename, array_bundle , iostat, iomsg)
107+ module subroutine load_npz_to_arrays (filename, arrays , iostat, iomsg)
130108 character(len=*), intent(in) :: filename
131- type(t_array_bundle ), intent(out) :: array_bundle
109+ type(t_array_wrapper ), allocatable, intent(out) :: arrays(:)
132110 integer, intent(out), optional :: iostat
133111 character(len=:), allocatable, intent(out), optional :: iomsg
134112
@@ -138,9 +116,9 @@ contains
138116
139117 call unzip(filename, unzipped_bundle, stat, msg)
140118 if (stat == 0) then
141- call load_raw_to_bundle (unzipped_bundle, array_bundle , stat, msg)
119+ call load_unzipped_bundle_to_arrays (unzipped_bundle, arrays , stat, msg)
142120 else
143- call identify_problem (filename, stat, msg)
121+ call identify_unzip_problem (filename, stat, msg)
144122 end if
145123
146124 if (present(iostat)) then
@@ -156,64 +134,77 @@ contains
156134 if (present(iomsg) .and. allocated(msg)) call move_alloc(msg, iomsg)
157135 end
158136
159- module subroutine load_raw_to_bundle (unzipped_bundle, array_bundle , stat, msg)
137+ module subroutine load_unzipped_bundle_to_arrays (unzipped_bundle, arrays , stat, msg)
160138 type(t_unzipped_bundle), intent(in) :: unzipped_bundle
161- type(t_array_bundle ), intent(out) :: array_bundle
139+ type(t_array_wrapper ), allocatable, intent(out) :: arrays(:)
162140 integer, intent(out) :: stat
163141 character(len=:), allocatable, intent(out) :: msg
164142
165143 integer :: i, io
144+ integer, allocatable :: vshape(:)
145+ character(len=:), allocatable :: this_type
146+
147+ allocate (arrays(size(unzipped_bundle%files)))
166148
167- allocate (array_bundle%files(size(unzipped_bundle%files)))
168149 do i = 1, size(unzipped_bundle%files)
169- array_bundle%files(i)%name = unzipped_bundle%files(i)%name
170- open (newunit=io, status='scratch', form='unformatted', access='stream', iostat=stat)
171- if (stat /= 0) return
172- write (io) unzipped_bundle%files(i)%data
173- call load_string_to_array(io, unzipped_bundle%files(i), array_bundle%files(i), stat, msg)
174- close (io, status='delete', iostat=stat)
150+ open (newunit=io, status='scratch', form='unformatted', access='stream', iostat=stat, iomsg=msg)
175151 if (stat /= 0) return
176- end do
177- end
178-
179- module subroutine load_string_to_array(io, unzipped_file, array, stat, msg)
180- integer, intent(in) :: io
181- type(t_unzipped_file), intent(in) :: unzipped_file
182- class(t_array), intent(inout) :: array
183- integer, intent(out) :: stat
184- character(len=:), allocatable, intent(out) :: msg
185152
186- #:for k1, t1 in KINDS_TYPES
187- #:for rank in RANKS
188- ${t1}$, allocatable :: array_${t1[0]}$${k1}$_${rank}$${ranksuffix(rank)}$
189- #:endfor
190- #:endfor
153+ write (io, iostat=stat) unzipped_bundle%files(i)%data
154+ if (stat /= 0) then
155+ msg = 'Failed to write unzipped data to scratch file.'
156+ close (io, status='delete'); return
157+ end if
191158
192- integer, allocatable :: vshape(:)
159+ rewind (io)
160+ call get_descriptor(io, unzipped_bundle%files(i)%name, this_type, vshape, stat, msg)
161+ if (stat /= 0) return
193162
194- select type (arr => array )
163+ select case (this_type )
195164 #:for k1, t1 in KINDS_TYPES
165+ case (type_${t1[0]}$${k1}$)
166+ select case (size(vshape))
196167 #:for rank in RANKS
197- type is (t_array_${t1[0]}$${k1}$_${rank}$)
198- call verify_npy_file(io, unzipped_file%name, type_${t1[0]}$${k1}$, vshape, ${rank}$, stat, msg)
199- if (stat /= 0) return
200- call allocate_array(array_${t1[0]}$${k1}$_${rank}$, vshape, stat)
201- if (stat /= 0) then
202- msg = "Failed to allocate array of type '"//type_${t1[0]}$${k1}$//"' "//&
203- & "with total size of "//to_string(product(vshape))
204- return
205- end if
206- read (io, iostat=stat) array_${t1[0]}$${k1}$_${rank}$${ranksuffix(rank)}$
207- arr%values = array_${t1[0]}$${k1}$_${rank}$${ranksuffix(rank)}$
168+ case (${rank}$)
169+ block
170+ ${t1}$, allocatable :: array${ranksuffix(rank)}$
171+
172+ call allocate_array(array, vshape, stat)
173+ if (stat /= 0) then
174+ msg = "Failed to allocate array of type '"//this_type//"'."; return
175+ end if
176+
177+ read (io, iostat=stat) array
178+ if (stat /= 0) then
179+ msg = "Failed to read array of type '"//this_type//"' "//&
180+ & 'with total size of '//to_string(product(vshape)); return
181+ end if
182+
183+ call arrays(i)%allocate_array(array, stat, msg)
184+ if (stat /= 0) then
185+ msg = "Failed to allocate array of type '"//this_type//"' "//&
186+ & 'with total size of '//to_string(product(vshape)); return
187+ end if
188+
189+ arrays(i)%array%name = unzipped_bundle%files(i)%name
190+ end block
208191 #:endfor
192+ case default
193+ stat = 1; msg = 'Unsupported rank for array of type '//this_type//': '// &
194+ & to_string(size(vshape))//'.'; return
195+ end select
209196 #:endfor
210- class default
211- stat = 1; msg = 'Unsupported array type.'; return
212- end select
197+ case default
198+ stat = 1; msg = 'Unsupported array type: '//this_type//'.'; return
199+ end select
200+
201+ close (io, status='delete')
202+ if (stat /= 0) return
203+ end do
213204 end
214205
215- !> Open file and try to identify the problem .
216- module subroutine identify_problem (filename, stat, msg)
206+ !> Open file and try to identify the cause of the error that occurred during unzip .
207+ module subroutine identify_unzip_problem (filename, stat, msg)
217208 character(len=*), intent(in) :: filename
218209 integer, intent(inout) :: stat
219210 character(len=:), allocatable, intent(inout) :: msg
@@ -291,7 +282,7 @@ contains
291282
292283 ! stat should be zero if no error occurred
293284 stat = 0
294-
285+
295286 read(io, iostat=stat) header
296287 if (stat /= 0) return
297288
0 commit comments