Skip to content

Commit e96ecf7

Browse files
committed
RF: Use np.sctypesDict to source scalar types
np.sctypes does not have a consistent value type, and does not enumerate all scalar types of a given kind.
1 parent 12db9ec commit e96ecf7

File tree

4 files changed

+30
-24
lines changed

4 files changed

+30
-24
lines changed

nibabel/spatialimages.py

Lines changed: 10 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -284,19 +284,17 @@ def supported_np_types(obj):
284284
set of numpy types that `obj` supports
285285
"""
286286
dt = obj.get_data_dtype()
287-
supported = []
288-
for name, np_types in np.sctypes.items():
289-
for np_type in np_types:
290-
try:
291-
obj.set_data_dtype(np_type)
292-
except HeaderDataError:
293-
continue
294-
# Did set work?
295-
if np.dtype(obj.get_data_dtype()) == np.dtype(np_type):
296-
supported.append(np_type)
297-
# Reset original header dtype
287+
supported = set()
288+
for np_type in set(np.sctypeDict.values()):
289+
try:
290+
obj.set_data_dtype(np_type)
291+
except HeaderDataError:
292+
continue
293+
# Did set work?
294+
if np.dtype(obj.get_data_dtype()) == np.dtype(np_type):
295+
supported.add(np_type)
298296
obj.set_data_dtype(dt)
299-
return set(supported)
297+
return supported
300298

301299

302300
class ImageDataError(Exception):

nibabel/tests/test_analyze.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -49,20 +49,20 @@
4949
PIXDIM0_MSG = 'pixdim[1,2,3] should be non-zero; setting 0 dims to 1'
5050

5151

52-
def add_intp(supported_np_types):
53-
# Add intp, uintp to supported types as necessary
54-
supported_dtypes = [np.dtype(t) for t in supported_np_types]
55-
for np_type in (np.intp, np.uintp):
56-
if np.dtype(np_type) in supported_dtypes:
57-
supported_np_types.add(np_type)
52+
def add_duplicate_types(supported_np_types):
53+
# Update supported numpy types with named scalar types that map to the same set of dtypes
54+
dtypes = {np.dtype(t) for t in supported_np_types}
55+
supported_np_types.update(
56+
scalar for scalar in set(np.sctypeDict.values()) if np.dtype(scalar) in dtypes
57+
)
5858

5959

6060
class TestAnalyzeHeader(tws._TestLabeledWrapStruct):
6161
header_class = AnalyzeHeader
6262
example_file = header_file
6363
sizeof_hdr = AnalyzeHeader.sizeof_hdr
6464
supported_np_types = {np.uint8, np.int16, np.int32, np.float32, np.float64, np.complex64}
65-
add_intp(supported_np_types)
65+
add_duplicate_types(supported_np_types)
6666

6767
def test_supported_types(self):
6868
hdr = self.header_class()

nibabel/tests/test_nifti1.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,7 @@ class TestNifti1PairHeader(tana.TestAnalyzeHeader, tspm.HeaderScalingMixin):
8080
)
8181
if have_binary128():
8282
supported_np_types = supported_np_types.union((np.longdouble, np.longcomplex))
83-
tana.add_intp(supported_np_types)
83+
tana.add_duplicate_types(supported_np_types)
8484

8585
def test_empty(self):
8686
tana.TestAnalyzeHeader.test_empty(self)

nibabel/tests/test_spm99analyze.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -35,10 +35,18 @@
3535
from ..volumeutils import _dt_min_max, apply_read_scaling
3636
from . import test_analyze
3737

38-
FLOAT_TYPES = np.sctypes['float']
39-
COMPLEX_TYPES = np.sctypes['complex']
40-
INT_TYPES = np.sctypes['int']
41-
UINT_TYPES = np.sctypes['uint']
38+
# np.sctypes values are lists of types with unique sizes
39+
# For testing, we want all concrete classes of a type
40+
# Key on kind, rather than abstract base classes, since timedelta64 is a signedinteger
41+
sctypes = {}
42+
for sctype in set(np.sctypeDict.values()):
43+
sctypes.setdefault(np.dtype(sctype).kind, []).append(sctype)
44+
45+
# Sort types to ensure that xdist doesn't complain about test order when we parametrize
46+
FLOAT_TYPES = sorted(sctypes['f'], key=lambda x: x.__name__)
47+
COMPLEX_TYPES = sorted(sctypes['c'], key=lambda x: x.__name__)
48+
INT_TYPES = sorted(sctypes['i'], key=lambda x: x.__name__)
49+
UINT_TYPES = sorted(sctypes['u'], key=lambda x: x.__name__)
4250
CFLOAT_TYPES = FLOAT_TYPES + COMPLEX_TYPES
4351
IUINT_TYPES = INT_TYPES + UINT_TYPES
4452
NUMERIC_TYPES = CFLOAT_TYPES + IUINT_TYPES

0 commit comments

Comments
 (0)