1414from distutils .version import StrictVersion
1515
1616from numpy .compat .py3k import asstr , asbytes
17- from ..openers import Opener , ImageOpener , HAVE_INDEXED_GZIP , BZ2File
17+ from ..openers import (Opener ,
18+ ImageOpener ,
19+ HAVE_INDEXED_GZIP ,
20+ BZ2File ,
21+ )
1822from ..tmpdirs import InTemporaryDirectory
1923from ..volumeutils import BinOpener
24+ from ..optpkg import optional_package
2025
2126import unittest
2227from unittest import mock
2328import pytest
2429from ..testing import error_warnings
2530
31+ pyzstd , HAVE_ZSTD , _ = optional_package ("pyzstd" )
32+
2633
2734class Lunk (object ):
2835 # bare file-like for testing
@@ -71,10 +78,13 @@ def test_Opener_various():
7178 import indexed_gzip as igzip
7279 with InTemporaryDirectory ():
7380 sobj = BytesIO ()
74- for input in ('test.txt' ,
75- 'test.txt.gz' ,
76- 'test.txt.bz2' ,
77- sobj ):
81+ files_to_test = ['test.txt' ,
82+ 'test.txt.gz' ,
83+ 'test.txt.bz2' ,
84+ sobj ]
85+ if HAVE_ZSTD :
86+ files_to_test += ['test.txt.zst' ]
87+ for input in files_to_test :
7888 with Opener (input , 'wb' ) as fobj :
7989 fobj .write (message )
8090 assert fobj .tell () == len (message )
@@ -240,6 +250,8 @@ def test_compressed_ext_case():
240250 class StrictOpener (Opener ):
241251 compress_ext_icase = False
242252 exts = ('gz' , 'bz2' , 'GZ' , 'gZ' , 'BZ2' , 'Bz2' )
253+ if HAVE_ZSTD :
254+ exts += ('zst' , 'ZST' , 'Zst' )
243255 with InTemporaryDirectory ():
244256 # Make a basic file to check type later
245257 with open (__file__ , 'rb' ) as a_file :
@@ -264,6 +276,8 @@ class StrictOpener(Opener):
264276 except ImportError :
265277 IndexedGzipFile = GzipFile
266278 assert isinstance (fobj .fobj , (GzipFile , IndexedGzipFile ))
279+ elif lext == 'zst' :
280+ assert isinstance (fobj .fobj , pyzstd .ZstdFile )
267281 else :
268282 assert isinstance (fobj .fobj , BZ2File )
269283
@@ -273,11 +287,14 @@ def test_name():
273287 sobj = BytesIO ()
274288 lunk = Lunk ('in ART' )
275289 with InTemporaryDirectory ():
276- for input in ('test.txt' ,
277- 'test.txt.gz' ,
278- 'test.txt.bz2' ,
279- sobj ,
280- lunk ):
290+ files_to_test = ['test.txt' ,
291+ 'test.txt.gz' ,
292+ 'test.txt.bz2' ,
293+ sobj ,
294+ lunk ]
295+ if HAVE_ZSTD :
296+ files_to_test += ['test.txt.zst' ]
297+ for input in files_to_test :
281298 exp_name = input if type (input ) == type ('' ) else None
282299 with Opener (input , 'wb' ) as fobj :
283300 assert fobj .name == exp_name
@@ -329,10 +346,13 @@ def test_iter():
329346""" .split ('\n ' )
330347 with InTemporaryDirectory ():
331348 sobj = BytesIO ()
332- for input , does_t in (('test.txt' , True ),
333- ('test.txt.gz' , False ),
334- ('test.txt.bz2' , False ),
335- (sobj , True )):
349+ files_to_test = [('test.txt' , True ),
350+ ('test.txt.gz' , False ),
351+ ('test.txt.bz2' , False ),
352+ (sobj , True )]
353+ if HAVE_ZSTD :
354+ files_to_test += [('test.txt.zst' , False )]
355+ for input , does_t in files_to_test :
336356 with Opener (input , 'wb' ) as fobj :
337357 for line in lines :
338358 fobj .write (asbytes (line + os .linesep ))
0 commit comments