|
45 | 45 |
|
46 | 46 | def assert_equal(x, y, msg_extra=None): |
47 | 47 | extra = '' if not msg_extra else f' ({msg_extra})' |
48 | | - if x.dtype in dh.float_dtypes: |
| 48 | + if x.dtype in dh.all_float_dtypes: |
49 | 49 | # It's too difficult to do an approximately equal test here because |
50 | 50 | # different routines can give completely different answers, and even |
51 | 51 | # when it does work, the elementwise comparisons are too slow. So for |
@@ -701,7 +701,8 @@ def test_tensordot(x1, x2, kw): |
701 | 701 | # TODO: vary shapes, vary contracted axes, test different axes arguments |
702 | 702 | res = xp.tensordot(x1, x2, **kw) |
703 | 703 |
|
704 | | - ph.assert_dtype("tensordot", [x1.dtype, x2.dtype], res.dtype) |
| 704 | + ph.assert_dtype("tensordot", in_dtype=[x1.dtype, x2.dtype], |
| 705 | + out_dtype=res.dtype) |
705 | 706 |
|
706 | 707 | axes = _axes = kw.get('axes', 2) |
707 | 708 |
|
@@ -785,9 +786,10 @@ def test_vecdot(x1, x2, kw): |
785 | 786 |
|
786 | 787 | res = xp.vecdot(x1, x2, **kw) |
787 | 788 |
|
788 | | - ph.assert_dtype("vecdot", [x1.dtype, x2.dtype], res.dtype) |
| 789 | + ph.assert_dtype("vecdot", in_dtype=[x1.dtype, x2.dtype], |
| 790 | + out_dtype=res.dtype) |
789 | 791 | # TODO: assert shape and elements |
790 | | - ph.assert_shape("vecdot", res.shape, expected_shape) |
| 792 | + ph.assert_shape("vecdot", out_shape=res.shape, expected=expected_shape) |
791 | 793 |
|
792 | 794 | if x1.dtype in dh.int_dtypes: |
793 | 795 | def true_val(x, y, axis=-1): |
@@ -827,9 +829,11 @@ def test_vector_norm(x, data): |
827 | 829 |
|
828 | 830 | _axes = sh.normalise_axis(axis, x.ndim) |
829 | 831 |
|
830 | | - ph.assert_keepdimable_shape('linalg.vector_norm', res.shape, x.shape, |
831 | | - _axes, keepdims, **kw) |
832 | | - ph.assert_dtype('linalg.vector_norm', x.dtype, res.dtype) |
| 832 | + ph.assert_keepdimable_shape('linalg.vector_norm', out_shape=res.shape, |
| 833 | + in_shape=x.shape, axes=_axes, |
| 834 | + keepdims=keepdims, kw=kw) |
| 835 | + ph.assert_dtype('linalg.vector_norm', in_dtype=x.dtype, |
| 836 | + out_dtype=res.dtype) |
833 | 837 |
|
834 | 838 | _kw = kw.copy() |
835 | 839 | _kw.pop('axis', None) |
|
0 commit comments