@@ -898,26 +898,28 @@ def set_data_dtype(self, datatype):
898898 >>> hdr.set_data_dtype(np.dtype(np.uint8))
899899 >>> hdr.get_data_dtype()
900900 dtype('uint8')
901- >>> hdr.set_data_dtype('implausible') #doctest: +IGNORE_EXCEPTION_DETAIL
901+ >>> hdr.set_data_dtype('implausible')
902902 Traceback (most recent call last):
903903 ...
904- HeaderDataError: data dtype "implausible" not recognized
905- >>> hdr.set_data_dtype('none') #doctest: +IGNORE_EXCEPTION_DETAIL
904+ nibabel.spatialimages. HeaderDataError: data dtype "implausible" not recognized
905+ >>> hdr.set_data_dtype('none')
906906 Traceback (most recent call last):
907907 ...
908- HeaderDataError: data dtype "none" known but not supported
909- >>> hdr.set_data_dtype(np.void) #doctest: +IGNORE_EXCEPTION_DETAIL
908+ nibabel.spatialimages. HeaderDataError: data dtype "none" known but not supported
909+ >>> hdr.set_data_dtype(np.void)
910910 Traceback (most recent call last):
911911 ...
912- HeaderDataError: data dtype "<type 'numpy.void'>" known but not supported
913- >>> hdr.set_data_dtype('int') #doctest: +IGNORE_EXCEPTION_DETAIL
912+ nibabel.spatialimages.HeaderDataError: data dtype "<class 'numpy.void'>" known
913+ but not supported
914+ >>> hdr.set_data_dtype('int')
914915 Traceback (most recent call last):
915916 ...
916917 ValueError: Invalid data type 'int'. Specify a sized integer, e.g., 'uint8' or numpy.int16.
917- >>> hdr.set_data_dtype(int) #doctest: +IGNORE_EXCEPTION_DETAIL
918+ >>> hdr.set_data_dtype(int)
918919 Traceback (most recent call last):
919920 ...
920- ValueError: Invalid data type 'int'. Specify a sized integer, e.g., 'uint8' or numpy.int16.
921+ ValueError: Invalid data type <class 'int'>. Specify a sized integer, e.g., 'uint8' or
922+ numpy.int16.
921923 >>> hdr.set_data_dtype('int64')
922924 >>> hdr.get_data_dtype() == np.dtype('int64')
923925 True
@@ -1799,6 +1801,10 @@ class Nifti1Pair(analyze.AnalyzeImage):
17991801 _meta_sniff_len = header_class .sizeof_hdr
18001802 rw = True
18011803
1804+ # If a _dtype_alias has been set, it can only be resolved by inspecting
1805+ # the data at serialization time
1806+ _dtype_alias = None
1807+
18021808 def __init__ (self , dataobj , affine , header = None ,
18031809 extra = None , file_map = None , dtype = None ):
18041810 # Special carve-out for 64 bit integers
@@ -2043,6 +2049,137 @@ def set_sform(self, affine, code=None, **kwargs):
20432049 else :
20442050 self ._affine [:] = self ._header .get_best_affine ()
20452051
2052+ def set_data_dtype (self , datatype ):
2053+ """ Set numpy dtype for data from code, dtype, type or alias
2054+
2055+ Using :py:class:`int` or ``"int"`` is disallowed, as these types
2056+ will be interpreted as ``np.int64``, which is almost never desired.
2057+ ``np.int64`` is permitted for those intent on making poor choices.
2058+
2059+ The following aliases are defined to allow for flexible specification:
2060+
2061+ * ``'mask'`` - Alias for ``uint8``
2062+ * ``'compat'`` - The nearest Analyze-compatible datatype
2063+ (``uint8``, ``int16``, ``int32``, ``float32``)
2064+ * ``'smallest'`` - The smallest Analyze-compatible integer
2065+ (``uint8``, ``int16``, ``int32``)
2066+
2067+ Dynamic aliases are resolved when ``get_data_dtype()`` is called
2068+ with a ``finalize=True`` flag. Until then, these aliases are not
2069+ written to the header and will not persist to new images.
2070+
2071+ Examples
2072+ --------
2073+ >>> ints = np.arange(24, dtype='i4').reshape((2,3,4))
2074+
2075+ >>> img = Nifti1Image(ints, np.eye(4))
2076+ >>> img.set_data_dtype(np.uint8)
2077+ >>> img.get_data_dtype()
2078+ dtype('uint8')
2079+ >>> img.set_data_dtype('mask')
2080+ >>> img.get_data_dtype()
2081+ dtype('uint8')
2082+ >>> img.set_data_dtype('compat')
2083+ >>> img.get_data_dtype()
2084+ 'compat'
2085+ >>> img.get_data_dtype(finalize=True)
2086+ dtype('<i4')
2087+ >>> img.get_data_dtype()
2088+ dtype('<i4')
2089+ >>> img.set_data_dtype('smallest')
2090+ >>> img.get_data_dtype()
2091+ 'smallest'
2092+ >>> img.get_data_dtype(finalize=True)
2093+ dtype('uint8')
2094+ >>> img.get_data_dtype()
2095+ dtype('uint8')
2096+
2097+ Note that floating point values will not be coerced to ``int``
2098+
2099+ >>> floats = np.arange(24, dtype='f4').reshape((2,3,4))
2100+ >>> img = Nifti1Image(floats, np.eye(4))
2101+ >>> img.set_data_dtype('smallest')
2102+ >>> img.get_data_dtype(finalize=True)
2103+ Traceback (most recent call last):
2104+ ...
2105+ ValueError: Cannot automatically cast array (of type float32) to an integer
2106+ type with fewer than 64 bits. Please set_data_dtype() to an explicit data type.
2107+
2108+ >>> arr = np.arange(1000, 1024, dtype='i4').reshape((2,3,4))
2109+ >>> img = Nifti1Image(arr, np.eye(4))
2110+ >>> img.set_data_dtype('smallest')
2111+ >>> img.set_data_dtype('implausible')
2112+ Traceback (most recent call last):
2113+ ...
2114+ nibabel.spatialimages.HeaderDataError: data dtype "implausible" not recognized
2115+ >>> img.set_data_dtype('none')
2116+ Traceback (most recent call last):
2117+ ...
2118+ nibabel.spatialimages.HeaderDataError: data dtype "none" known but not supported
2119+ >>> img.set_data_dtype(np.void)
2120+ Traceback (most recent call last):
2121+ ...
2122+ nibabel.spatialimages.HeaderDataError: data dtype "<class 'numpy.void'>" known
2123+ but not supported
2124+ >>> img.set_data_dtype('int')
2125+ Traceback (most recent call last):
2126+ ...
2127+ ValueError: Invalid data type 'int'. Specify a sized integer, e.g., 'uint8' or numpy.int16.
2128+ >>> img.set_data_dtype(int)
2129+ Traceback (most recent call last):
2130+ ...
2131+ ValueError: Invalid data type <class 'int'>. Specify a sized integer, e.g., 'uint8' or
2132+ numpy.int16.
2133+ >>> img.set_data_dtype('int64')
2134+ >>> img.get_data_dtype() == np.dtype('int64')
2135+ True
2136+ """
2137+ # Comparing dtypes to strings, numpy will attempt to call, e.g., dtype('mask'),
2138+ # so only check for aliases if the type is a string
2139+ # See https://github.com/numpy/numpy/issues/7242
2140+ if isinstance (datatype , str ):
2141+ # Static aliases
2142+ if datatype == 'mask' :
2143+ datatype = 'u1'
2144+ # Dynamic aliases
2145+ elif datatype in ('compat' , 'smallest' ):
2146+ self ._dtype_alias = datatype
2147+ return
2148+
2149+ self ._dtype_alias = None
2150+ super ().set_data_dtype (datatype )
2151+
2152+ def get_data_dtype (self , finalize = False ):
2153+ """ Get numpy dtype for data
2154+
2155+ If ``set_data_dtype()`` has been called with an alias
2156+ and ``finalize`` is ``False``, return the alias.
2157+ If ``finalize`` is ``True``, determine the appropriate dtype
2158+ from the image data object and set the final dtype in the
2159+ header before returning it.
2160+ """
2161+ if self ._dtype_alias is None :
2162+ return super ().get_data_dtype ()
2163+ if not finalize :
2164+ return self ._dtype_alias
2165+
2166+ datatype = None
2167+ if self ._dtype_alias == 'compat' :
2168+ datatype = _get_analyze_compat_dtype (self ._dataobj )
2169+ descrip = "an Analyze-compatible dtype"
2170+ elif self ._dtype_alias == 'smallest' :
2171+ datatype = _get_smallest_dtype (self ._dataobj )
2172+ descrip = "an integer type with fewer than 64 bits"
2173+ else :
2174+ raise ValueError (f"Unknown dtype alias { self ._dtype_alias } ." )
2175+ if datatype is None :
2176+ dt = get_obj_dtype (self ._dataobj )
2177+ raise ValueError (f"Cannot automatically cast array (of type { dt } ) to { descrip } ."
2178+ " Please set_data_dtype() to an explicit data type." )
2179+
2180+ self .set_data_dtype (datatype ) # Clears the alias
2181+ return super ().get_data_dtype ()
2182+
20462183 def as_reoriented (self , ornt ):
20472184 """Apply an orientation change and return a new image
20482185
@@ -2136,3 +2273,141 @@ def save(img, filename):
21362273 Nifti1Image .instance_to_filename (img , filename )
21372274 except ImageFileError :
21382275 Nifti1Pair .instance_to_filename (img , filename )
2276+
2277+
2278+ def _get_smallest_dtype (
2279+ arr ,
2280+ itypes = (np .uint8 , np .int16 , np .int32 ),
2281+ ftypes = (),
2282+ ):
2283+ """ Return the smallest "sensible" dtype that will hold the array data
2284+
2285+ The purpose of this function is to support automatic type selection
2286+ for serialization, so "sensible" here means well-supported in the NIfTI-1 world.
2287+
2288+ For floating point data, select between single- and double-precision.
2289+ For integer data, select among uint8, int16 and int32.
2290+
2291+ The test is for min/max range, so float64 is pretty unlikely to be hit.
2292+
2293+ Returns ``None`` if these dtypes do not suffice.
2294+
2295+ >>> _get_smallest_dtype(np.array([0, 1]))
2296+ dtype('uint8')
2297+ >>> _get_smallest_dtype(np.array([-1, 1]))
2298+ dtype('int16')
2299+ >>> _get_smallest_dtype(np.array([0, 256]))
2300+ dtype('int16')
2301+ >>> _get_smallest_dtype(np.array([-65536, 65536]))
2302+ dtype('int32')
2303+ >>> _get_smallest_dtype(np.array([-2147483648, 2147483648]))
2304+
2305+ By default floating point types are not searched:
2306+
2307+ >>> _get_smallest_dtype(np.array([1.]))
2308+ >>> _get_smallest_dtype(np.array([2. ** 1000]))
2309+ >>> _get_smallest_dtype(np.longdouble(2) ** 2000)
2310+ >>> _get_smallest_dtype(np.array([1+0j]))
2311+
2312+ However, this function can be passed "legal" floating point types, and
2313+ the logic works the same.
2314+
2315+ >>> _get_smallest_dtype(np.array([1.]), ftypes=('float32',))
2316+ dtype('float32')
2317+ >>> _get_smallest_dtype(np.array([2. ** 1000]), ftypes=('float32',))
2318+ >>> _get_smallest_dtype(np.longdouble(2) ** 2000, ftypes=('float32',))
2319+ >>> _get_smallest_dtype(np.array([1+0j]), ftypes=('float32',))
2320+ """
2321+ arr = np .asanyarray (arr )
2322+ if np .issubdtype (arr .dtype , np .floating ):
2323+ test_dts = ftypes
2324+ info = np .finfo
2325+ elif np .issubdtype (arr .dtype , np .integer ):
2326+ test_dts = itypes
2327+ info = np .iinfo
2328+ else :
2329+ return None
2330+
2331+ mn , mx = np .min (arr ), np .max (arr )
2332+ for dt in test_dts :
2333+ dtinfo = info (dt )
2334+ if dtinfo .min <= mn and mx <= dtinfo .max :
2335+ return np .dtype (dt )
2336+
2337+
2338+ def _get_analyze_compat_dtype (arr ):
2339+ """ Return an Analyze-compatible dtype that ``arr`` can be safely cast to
2340+
2341+ Analyze-compatible types are returned without inspection:
2342+
2343+ >>> _get_analyze_compat_dtype(np.uint8([0, 1]))
2344+ dtype('uint8')
2345+ >>> _get_analyze_compat_dtype(np.int16([0, 1]))
2346+ dtype('int16')
2347+ >>> _get_analyze_compat_dtype(np.int32([0, 1]))
2348+ dtype('int32')
2349+ >>> _get_analyze_compat_dtype(np.float32([0, 1]))
2350+ dtype('float32')
2351+
2352+ Signed ``int8`` are cast to ``uint8`` or ``int16`` based on value ranges:
2353+
2354+ >>> _get_analyze_compat_dtype(np.int8([0, 1]))
2355+ dtype('uint8')
2356+ >>> _get_analyze_compat_dtype(np.int8([-1, 1]))
2357+ dtype('int16')
2358+
2359+ Unsigned ``uint16`` are cast to ``int16`` or ``int32`` based on value ranges:
2360+
2361+ >>> _get_analyze_compat_dtype(np.uint16([32767]))
2362+ dtype('int16')
2363+ >>> _get_analyze_compat_dtype(np.uint16([65535]))
2364+ dtype('int32')
2365+
2366+ ``int32`` is returned for integer types and ``float32`` for floating point types:
2367+
2368+ >>> _get_analyze_compat_dtype(np.array([-1, 1]))
2369+ dtype('int32')
2370+ >>> _get_analyze_compat_dtype(np.array([-1., 1.]))
2371+ dtype('float32')
2372+
2373+ If the value ranges exceed 4 bytes or cannot be cast, then a ``ValueError`` is raised:
2374+
2375+ >>> _get_analyze_compat_dtype(np.array([0, 4294967295]))
2376+ Traceback (most recent call last):
2377+ ...
2378+ ValueError: Cannot find analyze-compatible dtype for array with dtype=int64
2379+ (min=0, max=4294967295)
2380+
2381+ >>> _get_analyze_compat_dtype([0., 2.e40])
2382+ Traceback (most recent call last):
2383+ ...
2384+ ValueError: Cannot find analyze-compatible dtype for array with dtype=float64
2385+ (min=0.0, max=2e+40)
2386+
2387+ Note that real-valued complex arrays cannot be safely cast.
2388+
2389+ >>> _get_analyze_compat_dtype(np.array([1+0j]))
2390+ Traceback (most recent call last):
2391+ ...
2392+ ValueError: Cannot find analyze-compatible dtype for array with dtype=complex128
2393+ (min=(1+0j), max=(1+0j))
2394+ """
2395+ arr = np .asanyarray (arr )
2396+ dtype = arr .dtype
2397+ if dtype in (np .uint8 , np .int16 , np .int32 , np .float32 ):
2398+ return dtype
2399+
2400+ if dtype == np .int8 :
2401+ return np .dtype ('uint8' if arr .min () >= 0 else 'int16' )
2402+ elif dtype == np .uint16 :
2403+ return np .dtype ('int16' if arr .max () <= np .iinfo (np .int16 ).max else 'int32' )
2404+
2405+ mn , mx = arr .min (), arr .max ()
2406+ if np .can_cast (mn , np .int32 ) and np .can_cast (mx , np .int32 ):
2407+ return np .dtype ('int32' )
2408+ if np .can_cast (mn , np .float32 ) and np .can_cast (mx , np .float32 ):
2409+ return np .dtype ('float32' )
2410+
2411+ raise ValueError (
2412+ f"Cannot find analyze-compatible dtype for array with dtype={ dtype } (min={ mn } , max={ mx } )"
2413+ )
0 commit comments