131131"""
132132from __future__ import annotations
133133
134- from typing import Type
134+ import io
135+ import typing as ty
136+ from typing import Literal , Sequence
135137
136138import numpy as np
137139
140+ from .arrayproxy import ArrayLike
138141from .dataobj_images import DataobjImage
139- from .filebasedimages import ImageFileError # noqa
140- from .filebasedimages import FileBasedHeader
142+ from .filebasedimages import FileBasedHeader , FileBasedImage , FileMap
141143from .fileslice import canonical_slicers
142144from .orientations import apply_orientation , inv_ornt_aff
143145from .viewers import OrthoSlicer3D
148150except ImportError : # PY38
149151 from functools import lru_cache as cache
150152
153+ if ty .TYPE_CHECKING : # pragma: no cover
154+ import numpy .typing as npt
155+
156+ SpatialImgT = ty .TypeVar ('SpatialImgT' , bound = 'SpatialImage' )
157+ SpatialHdrT = ty .TypeVar ('SpatialHdrT' , bound = 'SpatialHeader' )
158+
159+
160+ class HasDtype (ty .Protocol ):
161+ def get_data_dtype (self ) -> np .dtype :
162+ ... # pragma: no cover
163+
164+ def set_data_dtype (self , dtype : npt .DTypeLike ) -> None :
165+ ... # pragma: no cover
166+
167+
168+ @ty .runtime_checkable
169+ class SpatialProtocol (ty .Protocol ):
170+ def get_data_dtype (self ) -> np .dtype :
171+ ... # pragma: no cover
172+
173+ def get_data_shape (self ) -> ty .Tuple [int , ...]:
174+ ... # pragma: no cover
175+
176+ def get_zooms (self ) -> ty .Tuple [float , ...]:
177+ ... # pragma: no cover
178+
151179
152180class HeaderDataError (Exception ):
153181 """Class to indicate error in getting or setting header data"""
@@ -157,21 +185,33 @@ class HeaderTypeError(Exception):
157185 """Class to indicate error in parameters into header functions"""
158186
159187
160- class SpatialHeader (FileBasedHeader ):
188+ class SpatialHeader (FileBasedHeader , SpatialProtocol ):
161189 """Template class to implement header protocol"""
162190
163- default_x_flip = True
164- data_layout = 'F'
191+ default_x_flip : bool = True
192+ data_layout : Literal [ 'F' , 'C' ] = 'F'
165193
166- def __init__ (self , data_dtype = np .float32 , shape = (0 ,), zooms = None ):
194+ _dtype : np .dtype
195+ _shape : tuple [int , ...]
196+ _zooms : tuple [float , ...]
197+
198+ def __init__ (
199+ self ,
200+ data_dtype : npt .DTypeLike = np .float32 ,
201+ shape : Sequence [int ] = (0 ,),
202+ zooms : Sequence [float ] | None = None ,
203+ ):
167204 self .set_data_dtype (data_dtype )
168205 self ._zooms = ()
169206 self .set_data_shape (shape )
170207 if zooms is not None :
171208 self .set_zooms (zooms )
172209
173210 @classmethod
174- def from_header (klass , header = None ):
211+ def from_header (
212+ klass : type [SpatialHdrT ],
213+ header : SpatialProtocol | FileBasedHeader | ty .Mapping | None = None ,
214+ ) -> SpatialHdrT :
175215 if header is None :
176216 return klass ()
177217 # I can't do isinstance here because it is not necessarily true
@@ -180,74 +220,68 @@ def from_header(klass, header=None):
180220 # different field names
181221 if type (header ) == klass :
182222 return header .copy ()
183- return klass (header .get_data_dtype (), header .get_data_shape (), header .get_zooms ())
184-
185- @classmethod
186- def from_fileobj (klass , fileobj ):
187- raise NotImplementedError
188-
189- def write_to (self , fileobj ):
190- raise NotImplementedError
191-
192- def __eq__ (self , other ):
193- return (self .get_data_dtype (), self .get_data_shape (), self .get_zooms ()) == (
194- other .get_data_dtype (),
195- other .get_data_shape (),
196- other .get_zooms (),
197- )
198-
199- def __ne__ (self , other ):
200- return not self == other
223+ if isinstance (header , SpatialProtocol ):
224+ return klass (header .get_data_dtype (), header .get_data_shape (), header .get_zooms ())
225+ return super ().from_header (header )
226+
227+ def __eq__ (self , other : object ) -> bool :
228+ if isinstance (other , SpatialHeader ):
229+ return (self .get_data_dtype (), self .get_data_shape (), self .get_zooms ()) == (
230+ other .get_data_dtype (),
231+ other .get_data_shape (),
232+ other .get_zooms (),
233+ )
234+ return NotImplemented
201235
202- def copy (self ) :
236+ def copy (self : SpatialHdrT ) -> SpatialHdrT :
203237 """Copy object to independent representation
204238
205239 The copy should not be affected by any changes to the original
206240 object.
207241 """
208242 return self .__class__ (self ._dtype , self ._shape , self ._zooms )
209243
210- def get_data_dtype (self ):
244+ def get_data_dtype (self ) -> np . dtype :
211245 return self ._dtype
212246
213- def set_data_dtype (self , dtype ) :
247+ def set_data_dtype (self , dtype : npt . DTypeLike ) -> None :
214248 self ._dtype = np .dtype (dtype )
215249
216- def get_data_shape (self ):
250+ def get_data_shape (self ) -> tuple [ int , ...] :
217251 return self ._shape
218252
219- def set_data_shape (self , shape ) :
253+ def set_data_shape (self , shape : Sequence [ int ]) -> None :
220254 ndim = len (shape )
221255 if ndim == 0 :
222256 self ._shape = (0 ,)
223257 self ._zooms = (1.0 ,)
224258 return
225- self ._shape = tuple ([ int (s ) for s in shape ] )
259+ self ._shape = tuple (int (s ) for s in shape )
226260 # set any unset zooms to 1.0
227261 nzs = min (len (self ._zooms ), ndim )
228262 self ._zooms = self ._zooms [:nzs ] + (1.0 ,) * (ndim - nzs )
229263
230- def get_zooms (self ):
264+ def get_zooms (self ) -> tuple [ float , ...] :
231265 return self ._zooms
232266
233- def set_zooms (self , zooms ) :
234- zooms = tuple ([ float (z ) for z in zooms ] )
267+ def set_zooms (self , zooms : Sequence [ float ]) -> None :
268+ zooms = tuple (float (z ) for z in zooms )
235269 shape = self .get_data_shape ()
236270 ndim = len (shape )
237271 if len (zooms ) != ndim :
238272 raise HeaderDataError ('Expecting %d zoom values for ndim %d' % (ndim , ndim ))
239- if len ([ z for z in zooms if z < 0 ] ):
273+ if any ( z < 0 for z in zooms ):
240274 raise HeaderDataError ('zooms must be positive' )
241275 self ._zooms = zooms
242276
243- def get_base_affine (self ):
277+ def get_base_affine (self ) -> np . ndarray :
244278 shape = self .get_data_shape ()
245279 zooms = self .get_zooms ()
246280 return shape_zoom_affine (shape , zooms , self .default_x_flip )
247281
248282 get_best_affine = get_base_affine
249283
250- def data_to_fileobj (self , data , fileobj , rescale = True ):
284+ def data_to_fileobj (self , data : npt . ArrayLike , fileobj : io . IOBase , rescale : bool = True ):
251285 """Write array data `data` as binary to `fileobj`
252286
253287 Parameters
@@ -264,7 +298,7 @@ def data_to_fileobj(self, data, fileobj, rescale=True):
264298 dtype = self .get_data_dtype ()
265299 fileobj .write (data .astype (dtype ).tobytes (order = self .data_layout ))
266300
267- def data_from_fileobj (self , fileobj ) :
301+ def data_from_fileobj (self , fileobj : io . IOBase ) -> np . ndarray :
268302 """Read binary image data from `fileobj`"""
269303 dtype = self .get_data_dtype ()
270304 shape = self .get_data_shape ()
@@ -274,7 +308,7 @@ def data_from_fileobj(self, fileobj):
274308
275309
276310@cache
277- def _supported_np_types (klass ) :
311+ def _supported_np_types (klass : type [ HasDtype ]) -> set [ type [ np . generic ]] :
278312 """Numpy data types that instances of ``klass`` support
279313
280314 Parameters
@@ -308,7 +342,7 @@ def _supported_np_types(klass):
308342 return supported
309343
310344
311- def supported_np_types (obj ) :
345+ def supported_np_types (obj : HasDtype ) -> set [ type [ np . generic ]] :
312346 """Numpy data types that instance `obj` supports
313347
314348 Parameters
@@ -330,13 +364,15 @@ class ImageDataError(Exception):
330364 pass
331365
332366
333- class SpatialFirstSlicer :
367+ class SpatialFirstSlicer ( ty . Generic [ SpatialImgT ]) :
334368 """Slicing interface that returns a new image with an updated affine
335369
336370 Checks that an image's first three axes are spatial
337371 """
338372
339- def __init__ (self , img ):
373+ img : SpatialImgT
374+
375+ def __init__ (self , img : SpatialImgT ):
340376 # Local import to avoid circular import on module load
341377 from .imageclasses import spatial_axes_first
342378
@@ -346,7 +382,7 @@ def __init__(self, img):
346382 )
347383 self .img = img
348384
349- def __getitem__ (self , slicer ) :
385+ def __getitem__ (self , slicer : object ) -> SpatialImgT :
350386 try :
351387 slicer = self .check_slicing (slicer )
352388 except ValueError as err :
@@ -359,7 +395,7 @@ def __getitem__(self, slicer):
359395 affine = self .slice_affine (slicer )
360396 return self .img .__class__ (dataobj .copy (), affine , self .img .header )
361397
362- def check_slicing (self , slicer , return_spatial = False ):
398+ def check_slicing (self , slicer : object , return_spatial : bool = False ) -> tuple [ slice , ...] :
363399 """Canonicalize slicers and check for scalar indices in spatial dims
364400
365401 Parameters
@@ -376,21 +412,21 @@ def check_slicing(self, slicer, return_spatial=False):
376412 Validated slicer object that will slice image's `dataobj`
377413 without collapsing spatial dimensions
378414 """
379- slicer = canonical_slicers (slicer , self .img .shape )
415+ canonical = canonical_slicers (slicer , self .img .shape )
380416 # We can get away with this because we've checked the image's
381417 # first three axes are spatial.
382418 # More general slicers will need to be smarter, here.
383- spatial_slices = slicer [:3 ]
419+ spatial_slices = canonical [:3 ]
384420 for subslicer in spatial_slices :
385421 if subslicer is None :
386422 raise IndexError ('New axis not permitted in spatial dimensions' )
387423 elif isinstance (subslicer , int ):
388424 raise IndexError (
389425 'Scalar indices disallowed in spatial dimensions; Use `[x]` or `x:x+1`.'
390426 )
391- return spatial_slices if return_spatial else slicer
427+ return spatial_slices if return_spatial else canonical
392428
393- def slice_affine (self , slicer ) :
429+ def slice_affine (self , slicer : tuple [ slice , ...]) -> np . ndarray :
394430 """Retrieve affine for current image, if sliced by a given index
395431
396432 Applies scaling if down-sampling is applied, and adjusts the intercept
@@ -430,10 +466,19 @@ def slice_affine(self, slicer):
430466class SpatialImage (DataobjImage ):
431467 """Template class for volumetric (3D/4D) images"""
432468
433- header_class : Type [SpatialHeader ] = SpatialHeader
434- ImageSlicer = SpatialFirstSlicer
469+ header_class : type [SpatialHeader ] = SpatialHeader
470+ ImageSlicer : type [SpatialFirstSlicer ] = SpatialFirstSlicer
471+
472+ _header : SpatialHeader
435473
436- def __init__ (self , dataobj , affine , header = None , extra = None , file_map = None ):
474+ def __init__ (
475+ self ,
476+ dataobj : ArrayLike ,
477+ affine : np .ndarray ,
478+ header : FileBasedHeader | ty .Mapping | None = None ,
479+ extra : ty .Mapping | None = None ,
480+ file_map : FileMap | None = None ,
481+ ):
437482 """Initialize image
438483
439484 The image is a combination of (array-like, affine matrix, header), with
@@ -483,7 +528,7 @@ def __init__(self, dataobj, affine, header=None, extra=None, file_map=None):
483528 def affine (self ):
484529 return self ._affine
485530
486- def update_header (self ):
531+ def update_header (self ) -> None :
487532 """Harmonize header with image data and affine
488533
489534 >>> data = np.zeros((2,3,4))
@@ -512,7 +557,7 @@ def update_header(self):
512557 return
513558 self ._affine2header ()
514559
515- def _affine2header (self ):
560+ def _affine2header (self ) -> None :
516561 """Unconditionally set affine into the header"""
517562 RZS = self ._affine [:3 , :3 ]
518563 vox = np .sqrt (np .sum (RZS * RZS , axis = 0 ))
@@ -522,7 +567,7 @@ def _affine2header(self):
522567 zooms [:n_to_set ] = vox [:n_to_set ]
523568 hdr .set_zooms (zooms )
524569
525- def __str__ (self ):
570+ def __str__ (self ) -> str :
526571 shape = self .shape
527572 affine = self .affine
528573 return f"""
@@ -534,14 +579,14 @@ def __str__(self):
534579{ self ._header }
535580"""
536581
537- def get_data_dtype (self ):
582+ def get_data_dtype (self ) -> np . dtype :
538583 return self ._header .get_data_dtype ()
539584
540- def set_data_dtype (self , dtype ) :
585+ def set_data_dtype (self , dtype : npt . DTypeLike ) -> None :
541586 self ._header .set_data_dtype (dtype )
542587
543588 @classmethod
544- def from_image (klass , img ) :
589+ def from_image (klass : type [ SpatialImgT ] , img : SpatialImage | FileBasedImage ) -> SpatialImgT :
545590 """Class method to create new instance of own class from `img`
546591
547592 Parameters
@@ -555,15 +600,17 @@ def from_image(klass, img):
555600 cimg : ``spatialimage`` instance
556601 Image, of our own class
557602 """
558- return klass (
559- img .dataobj ,
560- img .affine ,
561- klass .header_class .from_header (img .header ),
562- extra = img .extra .copy (),
563- )
603+ if isinstance (img , SpatialImage ):
604+ return klass (
605+ img .dataobj ,
606+ img .affine ,
607+ klass .header_class .from_header (img .header ),
608+ extra = img .extra .copy (),
609+ )
610+ return super ().from_image (img )
564611
565612 @property
566- def slicer (self ) :
613+ def slicer (self : SpatialImgT ) -> SpatialFirstSlicer [ SpatialImgT ] :
567614 """Slicer object that returns cropped and subsampled images
568615
569616 The image is resliced in the current orientation; no rotation or
@@ -582,7 +629,7 @@ def slicer(self):
582629 """
583630 return self .ImageSlicer (self )
584631
585- def __getitem__ (self , idx ) :
632+ def __getitem__ (self , idx : object ) -> None :
586633 """No slicing or dictionary interface for images
587634
588635 Use the slicer attribute to perform cropping and subsampling at your
@@ -595,7 +642,7 @@ def __getitem__(self, idx):
595642 '`img.get_fdata()[slice]`'
596643 )
597644
598- def orthoview (self ):
645+ def orthoview (self ) -> OrthoSlicer3D :
599646 """Plot the image using OrthoSlicer3D
600647
601648 Returns
@@ -611,7 +658,7 @@ def orthoview(self):
611658 """
612659 return OrthoSlicer3D (self .dataobj , self .affine , title = self .get_filename ())
613660
614- def as_reoriented (self , ornt ) :
661+ def as_reoriented (self : SpatialImgT , ornt : Sequence [ Sequence [ int ]]) -> SpatialImgT :
615662 """Apply an orientation change and return a new image
616663
617664 If ornt is identity transform, return the original image, unchanged
0 commit comments