From 664f853c8dc49ddc4859b43e47c73334fba5aca8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fran=C3=A7ois=20Gauthier-Clerc?= Date: Thu, 6 Nov 2025 20:02:06 +0100 Subject: [PATCH] fix bad delegation behaviour with atleast_3d and improve atleast_nd unittests. --- src/array_api_extra/_delegation.py | 2 +- tests/test_funcs.py | 189 ++++++++++++++++++++--------- 2 files changed, 132 insertions(+), 59 deletions(-) diff --git a/src/array_api_extra/_delegation.py b/src/array_api_extra/_delegation.py index 289d21e4..4cdf255a 100644 --- a/src/array_api_extra/_delegation.py +++ b/src/array_api_extra/_delegation.py @@ -67,7 +67,7 @@ def atleast_nd(x: Array, /, *, ndim: int, xp: ModuleType | None = None) -> Array if xp is None: xp = array_namespace(x) - if 1 <= ndim <= 3 and ( + if 1 <= ndim <= 2 and ( is_numpy_namespace(xp) or is_jax_namespace(xp) or is_dask_namespace(xp) diff --git a/tests/test_funcs.py b/tests/test_funcs.py index ff050468..9e6b7296 100644 --- a/tests/test_funcs.py +++ b/tests/test_funcs.py @@ -54,6 +54,8 @@ lazy_xp_function(setdiff1d, jax_jit=False) lazy_xp_function(sinc) +NestedFloatList = list[float] | list["NestedFloatList"] + class TestApplyWhere: @staticmethod @@ -291,68 +293,139 @@ def test_0D(self, xp: ModuleType): y = atleast_nd(x, ndim=5) xp_assert_equal(y, xp.ones((1, 1, 1, 1, 1))) - def test_1D(self, xp: ModuleType): - x = xp.asarray([0, 1]) - - y = atleast_nd(x, ndim=0) - xp_assert_equal(y, x) - - y = atleast_nd(x, ndim=1) - xp_assert_equal(y, x) - - y = atleast_nd(x, ndim=2) - xp_assert_equal(y, xp.asarray([[0, 1]])) - - y = atleast_nd(x, ndim=5) - xp_assert_equal(y, xp.asarray([[[[[0, 1]]]]])) - - def test_2D(self, xp: ModuleType): - x = xp.asarray([[3.0]]) - - y = atleast_nd(x, ndim=0) - xp_assert_equal(y, x) - - y = atleast_nd(x, ndim=2) - xp_assert_equal(y, x) - - y = atleast_nd(x, ndim=3) - xp_assert_equal(y, 3 * xp.ones((1, 1, 1))) - - y = atleast_nd(x, ndim=5) - xp_assert_equal(y, 3 * xp.ones((1, 1, 1, 1, 1))) - - def test_3D(self, xp: ModuleType): - x = xp.asarray([[[3.0], [2.0]]]) - - y = atleast_nd(x, ndim=0) - xp_assert_equal(y, x) - - y = atleast_nd(x, ndim=2) - xp_assert_equal(y, x) - - y = atleast_nd(x, ndim=3) - xp_assert_equal(y, x) - - y = atleast_nd(x, ndim=5) - xp_assert_equal(y, xp.asarray([[[[[3.0], [2.0]]]]])) - - def test_5D(self, xp: ModuleType): - x = xp.ones((1, 1, 1, 1, 1)) - - y = atleast_nd(x, ndim=0) - xp_assert_equal(y, x) + @pytest.mark.parametrize( + ("x_data", "ndim", "expected_data"), + [ + # --- size-1 vector --- + ([3.0], 0, [3.0]), + ([3.0], 1, [3.0]), + ([3.0], 2, [[3.0]]), + ([3.0], 3, [[[3.0]]]), + ([3.0], 5, [[[[[3.0]]]]]), + # --- size-2 vector --- + ([0.0, 1.0], 0, [0.0, 1.0]), + ([0.0, 1.0], 1, [0.0, 1.0]), + ([0.0, 1.0], 2, [[0.0, 1.0]]), + ([0.0, 1.0], 5, [[[[[0.0, 1.0]]]]]), + ], + ) + def test_1D( + self, + x_data: NestedFloatList, + ndim: int, + expected_data: NestedFloatList, + xp: ModuleType, + ): + x = xp.asarray(x_data) + expected = xp.asarray(expected_data) + y = atleast_nd(x, ndim=ndim) + xp_assert_equal(y, expected) - y = atleast_nd(x, ndim=4) - xp_assert_equal(y, x) + @pytest.mark.parametrize( + ("x_data", "ndim", "expected_data"), + [ + # --- size-1 vector --- + ([[3.0]], 0, [[3.0]]), + ([[3.0]], 1, [[3.0]]), + ([[3.0]], 2, [[3.0]]), + ([[3.0]], 3, [[[3.0]]]), + ([[3.0]], 5, [[[[[3.0]]]]]), + # --- size-2 vector --- + ([[0.0], [1.0]], 0, [[0.0], [1.0]]), + ([[0.0, 1.0]], 1, [[0.0, 1.0]]), + ([[0.0, 1.0]], 2, [[0.0, 1.0]]), + ([[0.0], [1.0]], 3, [[[0.0], [1.0]]]), + ([[0.0, 1.0]], 5, [[[[[0.0, 1.0]]]]]), + ], + ) + def test_2D( + self, + x_data: NestedFloatList, + ndim: int, + expected_data: NestedFloatList, + xp: ModuleType, + ): + x = xp.asarray(x_data) + expected = xp.asarray(expected_data) + y = atleast_nd(x, ndim=ndim) + xp_assert_equal(y, expected) - y = atleast_nd(x, ndim=5) - xp_assert_equal(y, x) + @pytest.mark.parametrize( + ("x_data", "ndim", "expected_data"), + [ + ([[[0.0]], [[1.0]]], 0, [[[0.0]], [[1.0]]]), + ([[[0.0], [1.0]]], 1, [[[0.0], [1.0]]]), + ([[[0.0, 1.0]]], 2, [[[0.0, 1.0]]]), + ([[[0.0]], [[1.0]]], 3, [[[0.0]], [[1.0]]]), + ([[[0.0], [1.0]]], 5, [[[[[0.0], [1.0]]]]]), + ], + ) + def test_3D( + self, + x_data: NestedFloatList, + ndim: int, + expected_data: NestedFloatList, + xp: ModuleType, + ): + x = xp.asarray(x_data) + expected = xp.asarray(expected_data) + y = atleast_nd(x, ndim=ndim) + xp_assert_equal(y, expected) - y = atleast_nd(x, ndim=6) - xp_assert_equal(y, xp.ones((1, 1, 1, 1, 1, 1))) + @pytest.mark.parametrize( + ("x_data", "ndim", "expected_data"), + [ + ([[[[3.0], [2.0]]]], 0, [[[[3.0], [2.0]]]]), + ([[[[3.0, 2.0]]]], 2, [[[[3.0, 2.0]]]]), + ([[[[3.0]], [[2.0]]]], 4, [[[[3.0]], [[2.0]]]]), + ([[[[3.0]]], [[[2.0]]]], 5, [[[[[3.0]]], [[[2.0]]]]]), + ], + ) + def test_4D( + self, + x_data: NestedFloatList, + ndim: int, + expected_data: NestedFloatList, + xp: ModuleType, + ): + x = xp.asarray(x_data) + expected = xp.asarray(expected_data) + y = atleast_nd(x, ndim=ndim) + xp_assert_equal(y, expected) - y = atleast_nd(x, ndim=9) - xp_assert_equal(y, xp.ones((1, 1, 1, 1, 1, 1, 1, 1, 1))) + @pytest.mark.parametrize( + ("x_data", "ndim", "expected_data"), + [ + ([[[[[3.0]], [[2.0]], [[1.0]]]]], 0, [[[[[3.0]], [[2.0]], [[1.0]]]]]), + ([[[[[3.0, 2.0, 6.0]]]]], 2, [[[[[3.0, 2.0, 6.0]]]]]), + ( + [[[[[3.0]]], [[[2.0]]], [[[1.0]]]]], + 4, + [[[[[3.0]]], [[[2.0]]], [[[1.0]]]]], + ), + ( + [[[[[3.0]], [[1.0]]], [[[2.0]], [[1.0]]], [[[1.0]], [[1.0]]]]], + 6, + [[[[[[3.0]], [[1.0]]], [[[2.0]], [[1.0]]], [[[1.0]], [[1.0]]]]]], + ), + ( + [[[[[3.0]], [[1.0]]], [[[2.0]], [[1.0]]], [[[1.0]], [[1.0]]]]], + 9, + [[[[[[[[[3.0]], [[1.0]]], [[[2.0]], [[1.0]]], [[[1.0]], [[1.0]]]]]]]]], + ), + ], + ) + def test_5D( + self, + x_data: NestedFloatList, + ndim: int, + expected_data: NestedFloatList, + xp: ModuleType, + ): + x = xp.asarray(x_data) + expected = xp.asarray(expected_data) + y = atleast_nd(x, ndim=ndim) + xp_assert_equal(y, expected) def test_device(self, xp: ModuleType, device: Device): x = xp.asarray([1, 2, 3], device=device)