@@ -46,7 +46,21 @@ def _isdtype_impl(dtype, kind):
4646 elif isinstance (kind , tuple ):
4747 return any (_isdtype_impl (dtype , k ) for k in kind )
4848 else :
49- raise TypeError (f"Unsupported data type kind: { kind } " )
49+ raise TypeError (f"Unsupported type for dtype kind: { type (kind )} " )
50+
51+
52+ def _get_device_impl (d ):
53+ if d is None :
54+ return dpctl .select_default_device ()
55+ elif isinstance (d , dpctl .SyclDevice ):
56+ return d
57+ elif isinstance (d , (dpt .Device , dpctl .SyclQueue )):
58+ return d .sycl_device
59+ else :
60+ try :
61+ return dpctl .SyclDevice (d )
62+ except TypeError :
63+ raise TypeError (f"Unsupported type for device argument: { type (d )} " )
5064
5165
5266__array_api_version__ = "2023.12"
@@ -117,13 +131,13 @@ def default_dtypes(self, *, device=None):
117131 Returns a dictionary of default data types for ``device``.
118132
119133 Args:
120- device (Optional[:class:`dpctl.SyclDevice`, :class:`dpctl.SyclQueue`, :class:`dpctl.tensor.Device`]):
134+ device (Optional[:class:`dpctl.SyclDevice`, :class:`dpctl.SyclQueue`, :class:`dpctl.tensor.Device`, str ]):
121135 array API concept of device used in getting default data types.
122136 ``device`` can be ``None`` (in which case the default device
123- is used), an instance of :class:`dpctl.SyclDevice` corresponding
124- to a non-partitioned SYCL device, an instance of
125- :class:`dpctl.SyclQueue`, or a :class :`dpctl.tensor.Device`
126- object returned by :attr:`dpctl.tensor.usm_ndarray.device` .
137+ is used), an instance of :class:`dpctl.SyclDevice`, an instance
138+ of :class:`dpctl.SyclQueue`, a :class:`dpctl.tensor.Device`
139+ object returned by :attr :`dpctl.tensor.usm_ndarray.device`, or
140+ a filter selector string .
127141 Default: ``None``.
128142
129143 Returns:
@@ -135,10 +149,7 @@ def default_dtypes(self, *, device=None):
135149 - ``"integral"``: dtype
136150 - ``"indexing"``: dtype
137151 """
138- if device is None :
139- device = dpctl .select_default_device ()
140- elif isinstance (device , dpt .Device ):
141- device = device .sycl_device
152+ device = _get_device_impl (device )
142153 return {
143154 "real floating" : dpt .dtype (default_device_fp_type (device )),
144155 "complex floating" : dpt .dtype (default_device_complex_type (device )),
@@ -161,10 +172,10 @@ def dtypes(self, *, device=None, kind=None):
161172 device (Optional[:class:`dpctl.SyclDevice`, :class:`dpctl.SyclQueue`, :class:`dpctl.tensor.Device`, str]):
162173 array API concept of device used in getting default data types.
163174 ``device`` can be ``None`` (in which case the default device is
164- used), an instance of :class:`dpctl.SyclDevice` corresponding
165- to a non-partitioned SYCL device, an instance of
166- :class:`dpctl.SyclQueue`, or a :class :`dpctl.tensor.Device`
167- object returned by :attr:`dpctl.tensor.usm_ndarray.device` .
175+ used), an instance of :class:`dpctl.SyclDevice`, an instance of
176+ :class:`dpctl.SyclQueue`, a :class:`dpctl.tensor.Device`
177+ object returned by :attr :`dpctl.tensor.usm_ndarray.device`, or
178+ a filter selector string .
168179 Default: ``None``.
169180
170181 kind (Optional[str, Tuple[str, ...]]):
@@ -196,22 +207,20 @@ def dtypes(self, *, device=None, kind=None):
196207 a dictionary of the supported data types of the specified
197208 ``kind``
198209 """
199- if device is None :
200- device = dpctl .select_default_device ()
201- elif isinstance (device , dpt .Device ):
202- device = device .sycl_device
210+ device = _get_device_impl (device )
203211 _fp64 = device .has_aspect_fp64
204212 if kind is None :
205213 return {
206214 key : val
207215 for key , val in self ._all_dtypes .items ()
208- if (key != "float64" or _fp64 )
216+ if _fp64 or (key != "float64" and key != "complex128" )
209217 }
210218 else :
211219 return {
212220 key : val
213221 for key , val in self ._all_dtypes .items ()
214- if (key != "float64" or _fp64 ) and _isdtype_impl (val , kind )
222+ if (_fp64 or (key != "float64" and key != "complex128" ))
223+ and _isdtype_impl (val , kind )
215224 }
216225
217226 def devices (self ):
0 commit comments