@@ -19,30 +19,49 @@ def non_complex_dtypes():
1919 return xps .boolean_dtypes () | hh .real_dtypes
2020
2121
22+ def numeric_dtypes ():
23+ return xps .boolean_dtypes () | hh .real_dtypes | hh .complex_dtypes
24+
25+
2226def float32 (n : Union [int , float ]) -> float :
2327 return struct .unpack ("!f" , struct .pack ("!f" , float (n )))[0 ]
2428
2529
30+ def _float_match_complex (complex_dtype ):
31+ return xp .float32 if complex_dtype == xp .complex64 else xp .float64
32+
33+
2634@given (
27- x_dtype = non_complex_dtypes (),
28- dtype = non_complex_dtypes (),
35+ x_dtype = numeric_dtypes (),
36+ dtype = numeric_dtypes (),
2937 kw = hh .kwargs (copy = st .booleans ()),
3038 data = st .data (),
3139)
3240def test_astype (x_dtype , dtype , kw , data ):
3341 if xp .bool in (x_dtype , dtype ):
3442 elements_strat = hh .from_dtype (x_dtype )
3543 else :
36- m1 , M1 = dh .dtype_ranges [x_dtype ]
37- m2 , M2 = dh .dtype_ranges [dtype ]
44+
3845 if dh .is_int_dtype (x_dtype ):
3946 cast = int
40- elif x_dtype == xp .float32 :
47+ elif x_dtype in ( xp .float32 , xp . complex64 ) :
4148 cast = float32
4249 else :
4350 cast = float
51+
52+ real_dtype = x_dtype
53+ if x_dtype in (xp .complex64 , xp .complex128 ):
54+ real_dtype = _float_match_complex (x_dtype )
55+ m1 , M1 = dh .dtype_ranges [real_dtype ]
56+
57+ real_dtype = dtype
58+ if dtype in (xp .complex64 , xp .complex128 ):
59+ real_dtype = _float_match_complex (x_dtype )
60+ m2 , M2 = dh .dtype_ranges [real_dtype ]
61+
4462 min_value = cast (max (m1 , m2 ))
4563 max_value = cast (min (M1 , M2 ))
64+
4665 elements_strat = hh .from_dtype (
4766 x_dtype ,
4867 min_value = min_value ,
0 commit comments