@@ -162,9 +162,18 @@ def get(self):
162162 return self .o_
163163
164164
165- class WeakInexactType :
166- """Python type representing type of Python real- or
167- complex-valued floating point objects"""
165+ class WeakFloatingType :
166+ """Python type representing type of Python floating point objects"""
167+
168+ def __init__ (self , o ):
169+ self .o_ = o
170+
171+ def get (self ):
172+ return self .o_
173+
174+
175+ class WeakComplexType :
176+ """Python type representing type of Python complex floating point objects"""
168177
169178 def __init__ (self , o ):
170179 self .o_ = o
@@ -189,14 +198,17 @@ def _get_dtype(o, dev):
189198 return WeakBooleanType (o )
190199 if isinstance (o , int ):
191200 return WeakIntegralType (o )
192- if isinstance (o , (float , complex )):
193- return WeakInexactType (o )
201+ if isinstance (o , float ):
202+ return WeakFloatingType (o )
203+ if isinstance (o , complex ):
204+ return WeakComplexType (o )
194205 return np .object_
195206
196207
197208def _validate_dtype (dt ) -> bool :
198209 return isinstance (
199- dt , (WeakBooleanType , WeakInexactType , WeakIntegralType )
210+ dt ,
211+ (WeakBooleanType , WeakIntegralType , WeakFloatingType , WeakComplexType ),
200212 ) or (
201213 isinstance (dt , dpt .dtype )
202214 and dt
@@ -220,22 +232,24 @@ def _validate_dtype(dt) -> bool:
220232
221233
222234def _weak_type_num_kind (o ):
223- _map = {"?" : 0 , "i" : 1 , "f" : 2 }
235+ _map = {"?" : 0 , "i" : 1 , "f" : 2 , "c" : 3 }
224236 if isinstance (o , WeakBooleanType ):
225237 return _map ["?" ]
226238 if isinstance (o , WeakIntegralType ):
227239 return _map ["i" ]
228- if isinstance (o , WeakInexactType ):
240+ if isinstance (o , WeakFloatingType ):
229241 return _map ["f" ]
242+ if isinstance (o , WeakComplexType ):
243+ return _map ["c" ]
230244 raise TypeError (
231245 f"Unexpected type { o } while expecting "
232- "`WeakBooleanType`, `WeakIntegralType`, or "
233- "`WeakInexactType `."
246+ "`WeakBooleanType`, `WeakIntegralType`,"
247+ "`WeakFloatingType`, or `WeakComplexType `."
234248 )
235249
236250
237251def _strong_dtype_num_kind (o ):
238- _map = {"b" : 0 , "i" : 1 , "u" : 1 , "f" : 2 , "c" : 2 }
252+ _map = {"b" : 0 , "i" : 1 , "u" : 1 , "f" : 2 , "c" : 3 }
239253 if not isinstance (o , dpt .dtype ):
240254 raise TypeError
241255 k = o .kind
@@ -247,20 +261,29 @@ def _strong_dtype_num_kind(o):
247261def _resolve_weak_types (o1_dtype , o2_dtype , dev ):
248262 "Resolves weak data type per NEP-0050"
249263 if isinstance (
250- o1_dtype , (WeakBooleanType , WeakInexactType , WeakIntegralType )
264+ o1_dtype ,
265+ (WeakBooleanType , WeakIntegralType , WeakFloatingType , WeakComplexType ),
251266 ):
252267 if isinstance (
253- o2_dtype , (WeakBooleanType , WeakInexactType , WeakIntegralType )
268+ o2_dtype ,
269+ (
270+ WeakBooleanType ,
271+ WeakIntegralType ,
272+ WeakFloatingType ,
273+ WeakComplexType ,
274+ ),
254275 ):
255276 raise ValueError
256277 o1_kind_num = _weak_type_num_kind (o1_dtype )
257278 o2_kind_num = _strong_dtype_num_kind (o2_dtype )
258- if o1_kind_num > o2_kind_num or o1_kind_num == 2 :
279+ if o1_kind_num > o2_kind_num :
259280 if isinstance (o1_dtype , WeakBooleanType ):
260281 return dpt .bool , o2_dtype
261282 if isinstance (o1_dtype , WeakIntegralType ):
262283 return dpt .int64 , o2_dtype
263- if isinstance (o1_dtype .get (), complex ):
284+ if isinstance (o1_dtype , WeakComplexType ):
285+ if o2_dtype is dpt .float16 or o2_dtype is dpt .float32 :
286+ return dpt .complex64 , o2_dtype
264287 return (
265288 _to_device_supported_dtype (dpt .complex128 , dev ),
266289 o2_dtype ,
@@ -269,16 +292,19 @@ def _resolve_weak_types(o1_dtype, o2_dtype, dev):
269292 else :
270293 return o2_dtype , o2_dtype
271294 elif isinstance (
272- o2_dtype , (WeakBooleanType , WeakInexactType , WeakIntegralType )
295+ o2_dtype ,
296+ (WeakBooleanType , WeakIntegralType , WeakFloatingType , WeakComplexType ),
273297 ):
274298 o1_kind_num = _strong_dtype_num_kind (o1_dtype )
275299 o2_kind_num = _weak_type_num_kind (o2_dtype )
276- if o2_kind_num > o1_kind_num or o2_kind_num == 2 :
300+ if o2_kind_num > o1_kind_num :
277301 if isinstance (o2_dtype , WeakBooleanType ):
278302 return o1_dtype , dpt .bool
279303 if isinstance (o2_dtype , WeakIntegralType ):
280304 return o1_dtype , dpt .int64
281- if isinstance (o2_dtype .get (), complex ):
305+ if isinstance (o2_dtype , WeakComplexType ):
306+ if o1_dtype is dpt .float16 or o1_dtype is dpt .float32 :
307+ return o1_dtype , dpt .complex64
282308 return o1_dtype , _to_device_supported_dtype (dpt .complex128 , dev )
283309 return (
284310 o1_dtype ,
0 commit comments