1616pytestmark = pytest .mark .ci
1717
1818
19+ # TODO: test with complex dtypes
20+ def non_complex_dtypes ():
21+ return xps .boolean_dtypes () | xps .real_dtypes ()
22+
23+
1924def float32 (n : Union [int , float ]) -> float :
2025 return struct .unpack ("!f" , struct .pack ("!f" , float (n )))[0 ]
2126
2227
2328@given (
24- x_dtype = xps . scalar_dtypes (),
25- dtype = xps . scalar_dtypes (),
29+ x_dtype = non_complex_dtypes (),
30+ dtype = non_complex_dtypes (),
2631 kw = hh .kwargs (copy = st .booleans ()),
2732 data = st .data (),
2833)
@@ -101,7 +106,7 @@ def test_broadcast_to(x, data):
101106 # TODO: test values
102107
103108
104- @given (_from = xps . scalar_dtypes (), to = xps . scalar_dtypes (), data = st .data ())
109+ @given (_from = non_complex_dtypes (), to = non_complex_dtypes (), data = st .data ())
105110def test_can_cast (_from , to , data ):
106111 from_ = data .draw (
107112 st .just (_from ) | xps .arrays (dtype = _from , shape = hh .shapes ()), label = "from_"
@@ -114,10 +119,12 @@ def test_can_cast(_from, to, data):
114119 if _from == xp .bool :
115120 expected = to == xp .bool
116121 else :
117- for dtypes in [dh .all_int_dtypes , dh .float_dtypes ]:
122+ same_family = None
123+ for dtypes in [dh .all_int_dtypes , dh .float_dtypes , dh .complex_dtypes ]:
118124 if _from in dtypes :
119125 same_family = to in dtypes
120126 break
127+ assert same_family is not None # sanity check
121128 if same_family :
122129 from_min , from_max = dh .dtype_ranges [_from ]
123130 to_min , to_max = dh .dtype_ranges [to ]
0 commit comments