1212required, but we don't yet have a clean way to disable only those tests (see https://github.com/data-apis/array-api-tests/issues/25).
1313
1414"""
15- # TODO: test with complex dtypes where appropriate
16-
1715import pytest
1816from hypothesis import assume , given
1917from hypothesis .strategies import (booleans , composite , tuples , floats ,
2422import itertools
2523
2624from .array_helpers import assert_exactly_equal , asarray
27- from .hypothesis_helpers import (arrays , xps , shapes , kwargs , matrix_shapes ,
28- square_matrix_shapes , symmetric_matrices ,
25+ from .hypothesis_helpers import (arrays , all_floating_dtypes , xps , shapes ,
26+ kwargs , matrix_shapes , square_matrix_shapes ,
27+ symmetric_matrices ,
2928 positive_definite_matrices , MAX_ARRAY_SIZE ,
3029 invertible_matrices , two_mutual_arrays ,
3130 mutually_promotable_dtypes , one_d_shapes ,
@@ -210,7 +209,7 @@ def exact_cross(a, b):
210209
211210@pytest .mark .xp_extension ('linalg' )
212211@given (
213- x = arrays (dtype = xps . floating_dtypes (), shape = square_matrix_shapes ),
212+ x = arrays (dtype = all_floating_dtypes (), shape = square_matrix_shapes ),
214213)
215214def test_det (x ):
216215 res = linalg .det (x )
@@ -224,7 +223,7 @@ def test_det(x):
224223
225224@pytest .mark .xp_extension ('linalg' )
226225@given (
227- x = arrays (dtype = xps .real_dtypes (), shape = matrix_shapes ()),
226+ x = arrays (dtype = xps .scalar_dtypes (), shape = matrix_shapes ()),
228227 # offset may produce an overflow if it is too large. Supporting offsets
229228 # that are way larger than the array shape isn't very important.
230229 kw = kwargs (offset = integers (- MAX_ARRAY_SIZE , MAX_ARRAY_SIZE ))
@@ -382,7 +381,7 @@ def test_matrix_norm(x, kw):
382381@given (
383382 # Generate any square matrix if n >= 0 but only invertible matrices if n < 0
384383 x = matrix_power_n .flatmap (lambda n : invertible_matrices () if n < 0 else
385- arrays (dtype = xps . floating_dtypes (),
384+ arrays (dtype = all_floating_dtypes (),
386385 shape = square_matrix_shapes )),
387386 n = matrix_power_n ,
388387)
@@ -409,7 +408,7 @@ def test_matrix_rank(x, kw):
409408 linalg .matrix_rank (x , ** kw )
410409
411410@given (
412- x = arrays (dtype = xps .real_dtypes (), shape = matrix_shapes ()),
411+ x = arrays (dtype = xps .scalar_dtypes (), shape = matrix_shapes ()),
413412)
414413def test_matrix_transpose (x ):
415414 res = _array_module .matrix_transpose (x )
@@ -459,7 +458,7 @@ def test_pinv(x, kw):
459458
460459@pytest .mark .xp_extension ('linalg' )
461460@given (
462- x = arrays (dtype = xps . floating_dtypes (), shape = matrix_shapes ()),
461+ x = arrays (dtype = all_floating_dtypes (), shape = matrix_shapes ()),
463462 kw = kwargs (mode = sampled_from (['reduced' , 'complete' ]))
464463)
465464def test_qr (x , kw ):
@@ -495,7 +494,7 @@ def test_qr(x, kw):
495494
496495@pytest .mark .xp_extension ('linalg' )
497496@given (
498- x = arrays (dtype = xps . floating_dtypes (), shape = square_matrix_shapes ),
497+ x = arrays (dtype = all_floating_dtypes (), shape = square_matrix_shapes ),
499498)
500499def test_slogdet (x ):
501500 res = linalg .slogdet (x )
@@ -504,11 +503,16 @@ def test_slogdet(x):
504503
505504 sign , logabsdet = res
506505
507- assert sign .dtype == x .dtype , "slogdet().sign did not return the correct dtype"
508- assert sign .shape == x .shape [:- 2 ], "slogdet().sign did not return the correct shape"
509- assert logabsdet .dtype == x .dtype , "slogdet().logabsdet did not return the correct dtype"
510- assert logabsdet .shape == x .shape [:- 2 ], "slogdet().logabsdet did not return the correct shape"
511-
506+ ph .assert_dtype ("slogdet" , in_dtype = x .dtype , out_dtype = sign .dtype ,
507+ expected = x .dtype , repr_name = "sign.dtype" )
508+ ph .assert_shape ("slogdet" , out_shape = sign .shape , expected = x .shape [:- 2 ],
509+ repr_name = "sign.shape" )
510+ expected_dtype = dh .as_real_dtype (x .dtype )
511+ ph .assert_dtype ("slogdet" , in_dtype = x .dtype , out_dtype = logabsdet .dtype ,
512+ expected = expected_dtype , repr_name = "logabsdet.dtype" )
513+ ph .assert_shape ("slogdet" , out_shape = logabsdet .shape ,
514+ expected = x .shape [:- 2 ],
515+ repr_name = "logabsdet.shape" )
512516
513517 _test_stacks (lambda x : linalg .slogdet (x ).sign , x ,
514518 res = sign , dims = 0 )
@@ -550,7 +554,7 @@ def _x2_shapes(draw):
550554 return draw (stack_shapes )[1 ] + draw (x1 ).shape [- 1 :] + (end ,)
551555
552556 x2_shapes = one_of (x1 .map (lambda x : (x .shape [- 1 ],)), _x2_shapes ())
553- x2 = arrays (dtype = xps . floating_dtypes (), shape = x2_shapes )
557+ x2 = arrays (dtype = all_floating_dtypes (), shape = x2_shapes )
554558 return x1 , x2
555559
556560@pytest .mark .xp_extension ('linalg' )
@@ -734,7 +738,7 @@ def test_tensordot(x1, x2, kw):
734738
735739@pytest .mark .xp_extension ('linalg' )
736740@given (
737- x = arrays (dtype = xps .real_dtypes (), shape = matrix_shapes ()),
741+ x = arrays (dtype = xps .numeric_dtypes (), shape = matrix_shapes ()),
738742 # offset may produce an overflow if it is too large. Supporting offsets
739743 # that are way larger than the array shape isn't very important.
740744 kw = kwargs (offset = integers (- MAX_ARRAY_SIZE , MAX_ARRAY_SIZE ))
@@ -812,7 +816,7 @@ def true_val(x, y, axis=-1):
812816
813817@pytest .mark .xp_extension ('linalg' )
814818@given (
815- x = arrays (dtype = xps . floating_dtypes (), shape = shapes (min_side = 1 )),
819+ x = arrays (dtype = all_floating_dtypes (), shape = shapes (min_side = 1 )),
816820 data = data (),
817821)
818822def test_vector_norm (x , data ):
@@ -838,8 +842,9 @@ def test_vector_norm(x, data):
838842 ph .assert_keepdimable_shape ('linalg.vector_norm' , out_shape = res .shape ,
839843 in_shape = x .shape , axes = _axes ,
840844 keepdims = keepdims , kw = kw )
845+ expected_dtype = dh .as_real_dtype (x .dtype )
841846 ph .assert_dtype ('linalg.vector_norm' , in_dtype = x .dtype ,
842- out_dtype = res .dtype )
847+ out_dtype = res .dtype , expected = expected_dtype )
843848
844849 _kw = kw .copy ()
845850 _kw .pop ('axis' , None )
0 commit comments