@@ -19,18 +19,10 @@ def non_complex_dtypes():
1919 return xps .boolean_dtypes () | hh .real_dtypes
2020
2121
22- def numeric_dtypes ():
23- return xps .boolean_dtypes () | hh .real_dtypes | hh .complex_dtypes
24-
25-
2622def float32 (n : Union [int , float ]) -> float :
2723 return struct .unpack ("!f" , struct .pack ("!f" , float (n )))[0 ]
2824
2925
30- def _float_match_complex (complex_dtype ):
31- return xp .float32 if complex_dtype == xp .complex64 else xp .float64
32-
33-
3426@given (
3527 x_dtype = non_complex_dtypes (),
3628 dtype = non_complex_dtypes (),
@@ -115,46 +107,23 @@ def test_broadcast_to(x, data):
115107 # TODO: test values
116108
117109
118- @given (_from = numeric_dtypes (), to = numeric_dtypes (), data = st .data ())
119- def test_can_cast (_from , to , data ):
120- from_ = data .draw (
121- st .just (_from ) | hh .arrays (dtype = _from , shape = hh .shapes ()), label = "from_"
122- )
110+ @given (_from = hh .all_dtypes , to = hh .all_dtypes )
111+ def test_can_cast (_from , to ):
112+ out = xp .can_cast (_from , to )
123113
124- out = xp .can_cast (from_ , to )
114+ expected = False
115+ for other in dh .all_dtypes :
116+ if dh .promotion_table .get ((_from , other )) == to :
117+ expected = True
118+ break
125119
126120 f_func = f"[can_cast({ dh .dtype_to_name [_from ]} , { dh .dtype_to_name [to ]} )]"
127- assert isinstance (out , bool ), f"{ type (out )= } , but should be bool { f_func } "
128- if _from == xp .bool :
129- expected = to == xp .bool
130- else :
131- same_family = None
132- for dtypes in [dh .all_int_dtypes , dh .real_float_dtypes , dh .complex_dtypes ]:
133- if _from in dtypes :
134- same_family = to in dtypes
135- break
136- assert same_family is not None # sanity check
137- if same_family :
138- from_dtype = (_float_match_complex (_from )
139- if _from in (xp .complex64 , xp .complex128 )
140- else _from )
141- to_dtype = (_float_match_complex (to )
142- if to in (xp .complex64 , xp .complex128 )
143- else to )
144-
145- from_min , from_max = dh .dtype_ranges [from_dtype ]
146- to_min , to_max = dh .dtype_ranges [to_dtype ]
147- expected = from_min >= to_min and from_max <= to_max
148- else :
149- expected = False
150121 if expected :
151122 # cross-kind casting is not explicitly disallowed. We can only test
152- # the cases where it should return True. TODO: if expected=False,
153- # check that the array library actually allows such casts.
123+ # the cases where it should return True.
154124 assert out == expected , f"{ out = } , but should be { expected } { f_func } "
155125
156126
157-
158127@pytest .mark .parametrize ("dtype" , dh .real_float_dtypes )
159128def test_finfo (dtype ):
160129 out = xp .finfo (dtype )
0 commit comments