4545 _write_data ,
4646 _ftype4scaled_finite ,
4747 )
48- from ..openers import Opener , BZ2File
48+ from ..openers import Opener , BZ2File , HAVE_ZSTD
4949from ..casting import (floor_log2 , type_info , OK_FLOATS , shared_range )
5050
5151from ..deprecator import ExpiredDeprecationError
5656
5757from nibabel .testing import nullcontext , assert_dt_equal , assert_allclose_safely , suppress_warnings
5858
59+ # only import ZstdFile, if installed
60+ if HAVE_ZSTD :
61+ from ..openers import ZstdFile
62+
5963#: convenience variables for numpy types
6064FLOAT_TYPES = np .sctypes ['float' ]
6165COMPLEX_TYPES = np .sctypes ['complex' ]
6872def test__is_compressed_fobj ():
6973 # _is_compressed helper function
7074 with InTemporaryDirectory ():
71- for ext , opener , compressed in (('' , open , False ),
72- ('.gz' , gzip .open , True ),
73- ('.bz2' , BZ2File , True )):
75+ file_openers = [('' , open , False ),
76+ ('.gz' , gzip .open , True ),
77+ ('.bz2' , BZ2File , True )]
78+ if HAVE_ZSTD :
79+ file_openers += [('.zst' , ZstdFile , True )]
80+ for ext , opener , compressed in file_openers :
7481 fname = 'test.bin' + ext
7582 for mode in ('wb' , 'rb' ):
7683 fobj = opener (fname , mode )
@@ -88,12 +95,15 @@ def make_array(n, bytes):
8895 arr .flags .writeable = True
8996 return arr
9097
91- # Check whether file, gzip file, bz2 file reread memory from cache
98+ # Check whether file, gzip file, bz2, zst file reread memory from cache
9299 fname = 'test.bin'
93100 with InTemporaryDirectory ():
101+ openers = [open , gzip .open , BZ2File ]
102+ if HAVE_ZSTD :
103+ openers += [ZstdFile ]
94104 for n , opener in itertools .product (
95105 (256 , 1024 , 2560 , 25600 ),
96- ( open , gzip . open , BZ2File ) ):
106+ openers ):
97107 in_arr = np .arange (n , dtype = dtype )
98108 # Write array to file
99109 fobj_w = opener (fname , 'wb' )
@@ -230,7 +240,10 @@ def test_array_from_file_openers():
230240 dtype = np .dtype (np .float32 )
231241 in_arr = np .arange (24 , dtype = dtype ).reshape (shape )
232242 with InTemporaryDirectory ():
233- for ext , offset in itertools .product (('' , '.gz' , '.bz2' ),
243+ extensions = ['' , '.gz' , '.bz2' ]
244+ if HAVE_ZSTD :
245+ extensions += ['.zst' ]
246+ for ext , offset in itertools .product (extensions ,
234247 (0 , 5 , 10 )):
235248 fname = 'test.bin' + ext
236249 with Opener (fname , 'wb' ) as out_buf :
@@ -251,9 +264,12 @@ def test_array_from_file_reread():
251264 offset = 9
252265 fname = 'test.bin'
253266 with InTemporaryDirectory ():
267+ openers = [open , gzip .open , bz2 .BZ2File , BytesIO ]
268+ if HAVE_ZSTD :
269+ openers += [ZstdFile ]
254270 for shape , opener , dtt , order in itertools .product (
255271 ((64 ,), (64 , 65 ), (64 , 65 , 66 )),
256- ( open , gzip . open , bz2 . BZ2File , BytesIO ) ,
272+ openers ,
257273 (np .int16 , np .float32 ),
258274 ('F' , 'C' )):
259275 n_els = np .prod (shape )
@@ -901,7 +917,9 @@ def test_write_zeros():
901917def test_seek_tell ():
902918 # Test seek tell routine
903919 bio = BytesIO ()
904- in_files = bio , 'test.bin' , 'test.gz' , 'test.bz2'
920+ in_files = [bio , 'test.bin' , 'test.gz' , 'test.bz2' ]
921+ if HAVE_ZSTD :
922+ in_files += ['test.zst' ]
905923 start = 10
906924 end = 100
907925 diff = end - start
@@ -920,9 +938,12 @@ def test_seek_tell():
920938 fobj .write (b'\x01 ' * start )
921939 assert fobj .tell () == start
922940 # Files other than BZ2Files can seek forward on write, leaving
923- # zeros in their wake. BZ2Files can't seek when writing, unless
924- # we enable the write0 flag to seek_tell
925- if not write0 and in_file == 'test.bz2' : # Can't seek write in bz2
941+ # zeros in their wake. BZ2Files can't seek when writing,
942+ # unless we enable the write0 flag to seek_tell
943+ # ZstdFiles also does not support seek forward on write
944+ if (not write0 and
945+ (in_file == 'test.bz2' or
946+ in_file == 'test.zst' )): # Can't seek write in bz2, zst
926947 # write the zeros by hand for the read test below
927948 fobj .write (b'\x00 ' * diff )
928949 else :
@@ -946,7 +967,10 @@ def test_seek_tell():
946967 # Check we have the expected written output
947968 with ImageOpener (in_file , 'rb' ) as fobj :
948969 assert fobj .read () == b'\x01 ' * start + b'\x00 ' * diff + b'\x02 ' * tail
949- for in_file in ('test2.gz' , 'test2.bz2' ):
970+ input_files = ['test2.gz' , 'test2.bz2' ]
971+ if HAVE_ZSTD :
972+ input_files += ['test2.zst' ]
973+ for in_file in input_files :
950974 # Check failure of write seek backwards
951975 with ImageOpener (in_file , 'wb' ) as fobj :
952976 fobj .write (b'g' * 10 )
0 commit comments