@@ -92,7 +92,7 @@ def shape(self):
9292class ImageGrid (SampledSpatialData ):
9393 """Class to represent spaces of gridded data (images)."""
9494
95- __slots__ = ["_affine" , "_inverse" , "_ndindex" ]
95+ __slots__ = ["_affine" , "_inverse" , "_ndindex" , "_header" ]
9696
9797 def __init__ (self , image ):
9898 """Create a gridded sampling reference."""
@@ -101,6 +101,7 @@ def __init__(self, image):
101101
102102 self ._affine = image .affine
103103 self ._shape = image .shape
104+ self ._header = getattr (image , "header" , None )
104105
105106 self ._ndim = getattr (image , "ndim" , len (image .shape ))
106107 if self ._ndim >= 4 :
@@ -117,6 +118,11 @@ def affine(self):
117118 """Access the indexes-to-RAS affine."""
118119 return self ._affine
119120
121+ @property
122+ def header (self ):
123+ """Access the original reference's header."""
124+ return self ._header
125+
120126 @property
121127 def inverse (self ):
122128 """Access the RAS-to-indexes affine."""
@@ -293,12 +299,15 @@ def apply(
293299 )
294300
295301 if isinstance (_ref , ImageGrid ): # If reference is grid, reshape
302+ hdr = None
303+ if _ref .header is not None :
304+ hdr = _ref .header .copy ()
305+ hdr .set_data_dtype (output_dtype )
296306 moved = spatialimage .__class__ (
297307 resampled .reshape (_ref .shape ).astype (output_dtype ),
298308 _ref .affine ,
299- spatialimage . header
309+ hdr ,
300310 )
301- moved .set_data_dtype (output_dtype )
302311 return moved
303312
304313 return resampled
0 commit comments