1616import time
1717
1818from numpy .compat .py3k import asstr , asbytes
19- from ..openers import Opener , ImageOpener , HAVE_INDEXED_GZIP , BZ2File , DeterministicGzipFile
19+ from ..openers import (Opener ,
20+ ImageOpener ,
21+ HAVE_INDEXED_GZIP ,
22+ BZ2File ,
23+ DeterministicGzipFile ,
24+ )
2025from ..tmpdirs import InTemporaryDirectory
2126from ..volumeutils import BinOpener
27+ from ..optpkg import optional_package
2228
2329import unittest
2430from unittest import mock
2531import pytest
2632from ..testing import error_warnings
2733
34+ pyzstd , HAVE_ZSTD , _ = optional_package ("pyzstd" )
35+
2836
2937class Lunk (object ):
3038 # bare file-like for testing
@@ -73,10 +81,13 @@ def test_Opener_various():
7381 import indexed_gzip as igzip
7482 with InTemporaryDirectory ():
7583 sobj = BytesIO ()
76- for input in ('test.txt' ,
77- 'test.txt.gz' ,
78- 'test.txt.bz2' ,
79- sobj ):
84+ files_to_test = ['test.txt' ,
85+ 'test.txt.gz' ,
86+ 'test.txt.bz2' ,
87+ sobj ]
88+ if HAVE_ZSTD :
89+ files_to_test += ['test.txt.zst' ]
90+ for input in files_to_test :
8091 with Opener (input , 'wb' ) as fobj :
8192 fobj .write (message )
8293 assert fobj .tell () == len (message )
@@ -242,6 +253,8 @@ def test_compressed_ext_case():
242253 class StrictOpener (Opener ):
243254 compress_ext_icase = False
244255 exts = ('gz' , 'bz2' , 'GZ' , 'gZ' , 'BZ2' , 'Bz2' )
256+ if HAVE_ZSTD :
257+ exts += ('zst' , 'ZST' , 'Zst' )
245258 with InTemporaryDirectory ():
246259 # Make a basic file to check type later
247260 with open (__file__ , 'rb' ) as a_file :
@@ -266,6 +279,8 @@ class StrictOpener(Opener):
266279 except ImportError :
267280 IndexedGzipFile = GzipFile
268281 assert isinstance (fobj .fobj , (GzipFile , IndexedGzipFile ))
282+ elif lext == 'zst' :
283+ assert isinstance (fobj .fobj , pyzstd .ZstdFile )
269284 else :
270285 assert isinstance (fobj .fobj , BZ2File )
271286
@@ -275,11 +290,14 @@ def test_name():
275290 sobj = BytesIO ()
276291 lunk = Lunk ('in ART' )
277292 with InTemporaryDirectory ():
278- for input in ('test.txt' ,
279- 'test.txt.gz' ,
280- 'test.txt.bz2' ,
281- sobj ,
282- lunk ):
293+ files_to_test = ['test.txt' ,
294+ 'test.txt.gz' ,
295+ 'test.txt.bz2' ,
296+ sobj ,
297+ lunk ]
298+ if HAVE_ZSTD :
299+ files_to_test += ['test.txt.zst' ]
300+ for input in files_to_test :
283301 exp_name = input if type (input ) == type ('' ) else None
284302 with Opener (input , 'wb' ) as fobj :
285303 assert fobj .name == exp_name
@@ -331,10 +349,13 @@ def test_iter():
331349""" .split ('\n ' )
332350 with InTemporaryDirectory ():
333351 sobj = BytesIO ()
334- for input , does_t in (('test.txt' , True ),
335- ('test.txt.gz' , False ),
336- ('test.txt.bz2' , False ),
337- (sobj , True )):
352+ files_to_test = [('test.txt' , True ),
353+ ('test.txt.gz' , False ),
354+ ('test.txt.bz2' , False ),
355+ (sobj , True )]
356+ if HAVE_ZSTD :
357+ files_to_test += [('test.txt.zst' , False )]
358+ for input , does_t in files_to_test :
338359 with Opener (input , 'wb' ) as fobj :
339360 for line in lines :
340361 fobj .write (asbytes (line + os .linesep ))
0 commit comments