22from typing import Union
33
44import pytest
5- from hypothesis import given
5+ from hypothesis import given , assume
66from hypothesis import strategies as st
77
88from . import _array_module as xp
@@ -34,6 +34,8 @@ def _float_match_complex(complex_dtype):
3434 data = st .data (),
3535)
3636def test_astype (x_dtype , dtype , kw , data ):
37+ _complex_dtypes = (xp .complex64 , xp .complex128 )
38+
3739 if xp .bool in (x_dtype , dtype ):
3840 elements_strat = hh .from_dtype (x_dtype )
3941 else :
@@ -46,12 +48,12 @@ def test_astype(x_dtype, dtype, kw, data):
4648 cast = float
4749
4850 real_dtype = x_dtype
49- if x_dtype in ( xp . complex64 , xp . complex128 ) :
51+ if x_dtype in _complex_dtypes :
5052 real_dtype = _float_match_complex (x_dtype )
5153 m1 , M1 = dh .dtype_ranges [real_dtype ]
5254
5355 real_dtype = dtype
54- if dtype in ( xp . complex64 , xp . complex128 ) :
56+ if dtype in _complex_dtypes :
5557 real_dtype = _float_match_complex (x_dtype )
5658 m2 , M2 = dh .dtype_ranges [real_dtype ]
5759
@@ -69,6 +71,11 @@ def test_astype(x_dtype, dtype, kw, data):
6971 hh .arrays (dtype = x_dtype , shape = hh .shapes (), elements = elements_strat ), label = "x"
7072 )
7173
74+ # according to the spec, "Casting a complex floating-point array to a real-valued
75+ # data type should not be permitted."
76+ # https://data-apis.org/array-api/latest/API_specification/generated/array_api.astype.html#astype
77+ assume (not ((x_dtype in _complex_dtypes ) and (dtype not in _complex_dtypes )))
78+
7279 out = xp .astype (x , dtype , ** kw )
7380
7481 ph .assert_kw_dtype ("astype" , kw_dtype = dtype , out_dtype = out .dtype )
0 commit comments