55from typing import Any , Dict , NamedTuple , Sequence , Tuple , Union
66from warnings import warn
77
8- from . import api_version
98from . import _array_module as xp
9+ from . import api_version
1010from ._array_module import _UndefinedStub
11+ from ._array_module import mod as _xp
1112from .stubs import name_to_func
1213from .typing import DataType , ScalarType
1314
@@ -88,6 +89,12 @@ def __repr__(self):
8889 return f"EqualityMapping({ self } )"
8990
9091
92+ def _filter_stubs (* args ):
93+ for a in args :
94+ if not isinstance (a , _UndefinedStub ):
95+ yield a
96+
97+
9198_uint_names = ("uint8" , "uint16" , "uint32" , "uint64" )
9299_int_names = ("int8" , "int16" , "int32" , "int64" )
93100_float_names = ("float32" , "float64" )
@@ -113,7 +120,14 @@ def __repr__(self):
113120bool_and_all_int_dtypes = (xp .bool ,) + all_int_dtypes
114121
115122
116- dtype_to_name = EqualityMapping ([(getattr (xp , name ), name ) for name in _dtype_names ])
123+ _dtype_name_pairs = []
124+ for name in _dtype_names :
125+ try :
126+ dtype = getattr (_xp , name )
127+ except AttributeError :
128+ continue
129+ _dtype_name_pairs .append ((dtype , name ))
130+ dtype_to_name = EqualityMapping (_dtype_name_pairs )
117131
118132
119133dtype_to_scalars = EqualityMapping (
@@ -173,12 +187,13 @@ class MinMax(NamedTuple):
173187 ]
174188)
175189
190+
176191dtype_nbits = EqualityMapping (
177- [(d , 8 ) for d in [ xp .int8 , xp .uint8 ] ]
178- + [(d , 16 ) for d in [ xp .int16 , xp .uint16 ] ]
179- + [(d , 32 ) for d in [ xp .int32 , xp .uint32 , xp .float32 ] ]
180- + [(d , 64 ) for d in [ xp .int64 , xp .uint64 , xp .float64 , xp .complex64 ] ]
181- + [(xp . complex128 , 128 )]
192+ [(d , 8 ) for d in _filter_stubs ( xp .int8 , xp .uint8 ) ]
193+ + [(d , 16 ) for d in _filter_stubs ( xp .int16 , xp .uint16 ) ]
194+ + [(d , 32 ) for d in _filter_stubs ( xp .int32 , xp .uint32 , xp .float32 ) ]
195+ + [(d , 64 ) for d in _filter_stubs ( xp .int64 , xp .uint64 , xp .float64 , xp .complex64 ) ]
196+ + [(d , 128 ) for d in _filter_stubs ( xp . complex128 )]
182197)
183198
184199
@@ -265,7 +280,6 @@ class MinMax(NamedTuple):
265280 ((xp .complex64 , xp .complex64 ), xp .complex64 ),
266281 ((xp .complex64 , xp .complex128 ), xp .complex128 ),
267282 ((xp .complex128 , xp .complex128 ), xp .complex128 ),
268-
269283]
270284_numeric_promotions += [((d2 , d1 ), res ) for (d1 , d2 ), res in _numeric_promotions ]
271285_promotion_table = list (set (_numeric_promotions ))
0 commit comments