77#
88### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ##
99"""Context manager openers for various fileobject types"""
10+ from __future__ import annotations
11+
1012import gzip
11- import warnings
13+ import io
14+ import typing as ty
1215from bz2 import BZ2File
1316from os .path import splitext
1417
15- from packaging .version import Version
16-
1718from nibabel .optpkg import optional_package
1819
19- # is indexed_gzip present and modern?
20- try :
21- import indexed_gzip as igzip # type: ignore
20+ if ty .TYPE_CHECKING : # pragma: no cover
21+ from types import TracebackType
2222
23- version = igzip .__version__
23+ import pyzstd
24+ from _typeshed import WriteableBuffer
2425
25- HAVE_INDEXED_GZIP = True
26+ ModeRT = ty .Literal ['r' , 'rt' ]
27+ ModeRB = ty .Literal ['rb' ]
28+ ModeWT = ty .Literal ['w' , 'wt' ]
29+ ModeWB = ty .Literal ['wb' ]
30+ ModeR = ty .Union [ModeRT , ModeRB ]
31+ ModeW = ty .Union [ModeWT , ModeWB ]
32+ Mode = ty .Union [ModeR , ModeW ]
2633
27- # < 0.7 - no good
28- if Version (version ) < Version ('0.7.0' ):
29- warnings .warn (f'indexed_gzip is present, but too old (>= 0.7.0 required): { version } )' )
30- HAVE_INDEXED_GZIP = False
31- # >= 0.8 SafeIndexedGzipFile renamed to IndexedGzipFile
32- elif Version (version ) < Version ('0.8.0' ):
33- IndexedGzipFile = igzip .SafeIndexedGzipFile
34- else :
35- IndexedGzipFile = igzip .IndexedGzipFile
36- del igzip , version
34+ OpenerDef = tuple [ty .Callable [..., io .IOBase ], tuple [str , ...]]
35+ else :
36+ pyzstd = optional_package ('pyzstd' )[0 ]
37+
38+
39+ @ty .runtime_checkable
40+ class Fileish (ty .Protocol ):
41+ def read (self , size : int = - 1 , / ) -> bytes :
42+ ... # pragma: no cover
43+
44+ def write (self , b : bytes , / ) -> int | None :
45+ ... # pragma: no cover
46+
47+
48+ try :
49+ from indexed_gzip import IndexedGzipFile # type: ignore
3750
51+ HAVE_INDEXED_GZIP = True
3852except ImportError :
3953 # nibabel.openers.IndexedGzipFile is imported by nibabel.volumeutils
4054 # to detect compressed file types, so we give a fallback value here.
@@ -49,35 +63,63 @@ class DeterministicGzipFile(gzip.GzipFile):
4963 to a modification time (``mtime``) of 0 seconds.
5064 """
5165
52- def __init__ (self , filename = None , mode = None , compresslevel = 9 , fileobj = None , mtime = 0 ):
53- # These two guards are copied from
66+ def __init__ (
67+ self ,
68+ filename : str | None = None ,
69+ mode : Mode | None = None ,
70+ compresslevel : int = 9 ,
71+ fileobj : io .FileIO | None = None ,
72+ mtime : int = 0 ,
73+ ):
74+ if mode is None :
75+ mode = 'rb'
76+ modestr : str = mode
77+
78+ # These two guards are adapted from
5479 # https://github.com/python/cpython/blob/6ab65c6/Lib/gzip.py#L171-L174
55- if mode and 'b' not in mode :
56- mode += ' b'
80+ if 'b' not in modestr :
81+ modestr = f' { mode } b'
5782 if fileobj is None :
58- fileobj = self .myfileobj = open (filename , mode or 'rb' )
83+ if filename is None :
84+ raise TypeError ('Must define either fileobj or filename' )
85+ # Cast because GzipFile.myfileobj has type io.FileIO while open returns ty.IO
86+ fileobj = self .myfileobj = ty .cast (io .FileIO , open (filename , modestr ))
5987 return super ().__init__ (
60- filename = '' , mode = mode , compresslevel = compresslevel , fileobj = fileobj , mtime = mtime
88+ filename = '' ,
89+ mode = modestr ,
90+ compresslevel = compresslevel ,
91+ fileobj = fileobj ,
92+ mtime = mtime ,
6193 )
6294
6395
64- def _gzip_open (filename , mode = 'rb' , compresslevel = 9 , mtime = 0 , keep_open = False ):
96+ def _gzip_open (
97+ filename : str ,
98+ mode : Mode = 'rb' ,
99+ compresslevel : int = 9 ,
100+ mtime : int = 0 ,
101+ keep_open : bool = False ,
102+ ) -> gzip .GzipFile :
103+
104+ if not HAVE_INDEXED_GZIP or mode != 'rb' :
105+ gzip_file = DeterministicGzipFile (filename , mode , compresslevel , mtime = mtime )
65106
66107 # use indexed_gzip if possible for faster read access. If keep_open ==
67108 # True, we tell IndexedGzipFile to keep the file handle open. Otherwise
68109 # the IndexedGzipFile will close/open the file on each read.
69- if HAVE_INDEXED_GZIP and mode == 'rb' :
70- gzip_file = IndexedGzipFile (filename , drop_handles = not keep_open )
71-
72- # Fall-back to built-in GzipFile
73110 else :
74- gzip_file = DeterministicGzipFile (filename , mode , compresslevel , mtime = mtime )
111+ gzip_file = IndexedGzipFile (filename , drop_handles = not keep_open )
75112
76113 return gzip_file
77114
78115
79- def _zstd_open (filename , mode = 'r' , * , level_or_option = None , zstd_dict = None ):
80- pyzstd = optional_package ('pyzstd' )[0 ]
116+ def _zstd_open (
117+ filename : str ,
118+ mode : Mode = 'r' ,
119+ * ,
120+ level_or_option : int | dict | None = None ,
121+ zstd_dict : pyzstd .ZstdDict | None = None ,
122+ ) -> pyzstd .ZstdFile :
81123 return pyzstd .ZstdFile (filename , mode , level_or_option = level_or_option , zstd_dict = zstd_dict )
82124
83125
@@ -104,7 +146,7 @@ class Opener:
104146 gz_def = (_gzip_open , ('mode' , 'compresslevel' , 'mtime' , 'keep_open' ))
105147 bz2_def = (BZ2File , ('mode' , 'buffering' , 'compresslevel' ))
106148 zstd_def = (_zstd_open , ('mode' , 'level_or_option' , 'zstd_dict' ))
107- compress_ext_map = {
149+ compress_ext_map : dict [ str | None , OpenerDef ] = {
108150 '.gz' : gz_def ,
109151 '.bz2' : bz2_def ,
110152 '.zst' : zstd_def ,
@@ -121,19 +163,19 @@ class Opener:
121163 'w' : default_zst_compresslevel ,
122164 }
123165 #: whether to ignore case looking for compression extensions
124- compress_ext_icase = True
166+ compress_ext_icase : bool = True
167+
168+ fobj : io .IOBase
125169
126- def __init__ (self , fileish , * args , ** kwargs ):
127- if self . _is_fileobj (fileish ):
170+ def __init__ (self , fileish : str | io . IOBase , * args , ** kwargs ):
171+ if isinstance (fileish , ( io . IOBase , Fileish ) ):
128172 self .fobj = fileish
129173 self .me_opened = False
130- self ._name = None
174+ self ._name = getattr ( fileish , 'name' , None )
131175 return
132176 opener , arg_names = self ._get_opener_argnames (fileish )
133177 # Get full arguments to check for mode and compresslevel
134- full_kwargs = kwargs .copy ()
135- n_args = len (args )
136- full_kwargs .update (dict (zip (arg_names [:n_args ], args )))
178+ full_kwargs = {** kwargs , ** dict (zip (arg_names , args ))}
137179 # Set default mode
138180 if 'mode' not in full_kwargs :
139181 mode = 'rb'
@@ -155,7 +197,7 @@ def __init__(self, fileish, *args, **kwargs):
155197 self ._name = fileish
156198 self .me_opened = True
157199
158- def _get_opener_argnames (self , fileish ) :
200+ def _get_opener_argnames (self , fileish : str ) -> OpenerDef :
159201 _ , ext = splitext (fileish )
160202 if self .compress_ext_icase :
161203 ext = ext .lower ()
@@ -168,16 +210,12 @@ def _get_opener_argnames(self, fileish):
168210 return self .compress_ext_map [ext ]
169211 return self .compress_ext_map [None ]
170212
171- def _is_fileobj (self , obj ):
172- """Is `obj` a file-like object?"""
173- return hasattr (obj , 'read' ) and hasattr (obj , 'write' )
174-
175213 @property
176- def closed (self ):
214+ def closed (self ) -> bool :
177215 return self .fobj .closed
178216
179217 @property
180- def name (self ):
218+ def name (self ) -> str | None :
181219 """Return ``self.fobj.name`` or self._name if not present
182220
183221 self._name will be None if object was created with a fileobj, otherwise
@@ -186,42 +224,53 @@ def name(self):
186224 return self ._name
187225
188226 @property
189- def mode (self ):
190- return self .fobj .mode
227+ def mode (self ) -> str :
228+ # Check and raise our own error for type narrowing purposes
229+ if hasattr (self .fobj , 'mode' ):
230+ return self .fobj .mode
231+ raise AttributeError (f'{ self .fobj .__class__ .__name__ } has no attribute "mode"' )
191232
192- def fileno (self ):
233+ def fileno (self ) -> int :
193234 return self .fobj .fileno ()
194235
195- def read (self , * args , ** kwargs ) :
196- return self .fobj .read (* args , ** kwargs )
236+ def read (self , size : int = - 1 , / ) -> bytes :
237+ return self .fobj .read (size )
197238
198- def readinto (self , * args , ** kwargs ):
199- return self .fobj .readinto (* args , ** kwargs )
239+ def readinto (self , buffer : WriteableBuffer , / ) -> int | None :
240+ # Check and raise our own error for type narrowing purposes
241+ if hasattr (self .fobj , 'readinto' ):
242+ return self .fobj .readinto (buffer )
243+ raise AttributeError (f'{ self .fobj .__class__ .__name__ } has no attribute "readinto"' )
200244
201- def write (self , * args , ** kwargs ) :
202- return self .fobj .write (* args , ** kwargs )
245+ def write (self , b : bytes , / ) -> int | None :
246+ return self .fobj .write (b )
203247
204- def seek (self , * args , ** kwargs ) :
205- return self .fobj .seek (* args , ** kwargs )
248+ def seek (self , pos : int , whence : int = 0 , / ) -> int :
249+ return self .fobj .seek (pos , whence )
206250
207- def tell (self , * args , ** kwargs ) :
208- return self .fobj .tell (* args , ** kwargs )
251+ def tell (self , / ) -> int :
252+ return self .fobj .tell ()
209253
210- def close (self , * args , ** kwargs ) :
211- return self .fobj .close (* args , ** kwargs )
254+ def close (self , / ) -> None :
255+ return self .fobj .close ()
212256
213- def __iter__ (self ):
257+ def __iter__ (self ) -> ty . Iterator [ bytes ] :
214258 return iter (self .fobj )
215259
216- def close_if_mine (self ):
260+ def close_if_mine (self ) -> None :
217261 """Close ``self.fobj`` iff we opened it in the constructor"""
218262 if self .me_opened :
219263 self .close ()
220264
221- def __enter__ (self ):
265+ def __enter__ (self ) -> Opener :
222266 return self
223267
224- def __exit__ (self , exc_type , exc_val , exc_tb ):
268+ def __exit__ (
269+ self ,
270+ exc_type : type [BaseException ] | None ,
271+ exc_val : BaseException | None ,
272+ exc_tb : TracebackType | None ,
273+ ) -> None :
225274 self .close_if_mine ()
226275
227276
0 commit comments