Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/array_api_extra/_delegation.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ def atleast_nd(x: Array, /, *, ndim: int, xp: ModuleType | None = None) -> Array
if xp is None:
xp = array_namespace(x)

if 1 <= ndim <= 3 and (
if 1 <= ndim <= 2 and (
is_numpy_namespace(xp)
or is_jax_namespace(xp)
or is_dask_namespace(xp)
Expand Down
189 changes: 131 additions & 58 deletions tests/test_funcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,8 @@
lazy_xp_function(setdiff1d, jax_jit=False)
lazy_xp_function(sinc)

NestedFloatList = list[float] | list["NestedFloatList"]


class TestApplyWhere:
@staticmethod
Expand Down Expand Up @@ -291,68 +293,139 @@ def test_0D(self, xp: ModuleType):
y = atleast_nd(x, ndim=5)
xp_assert_equal(y, xp.ones((1, 1, 1, 1, 1)))

def test_1D(self, xp: ModuleType):
x = xp.asarray([0, 1])

y = atleast_nd(x, ndim=0)
xp_assert_equal(y, x)

y = atleast_nd(x, ndim=1)
xp_assert_equal(y, x)

y = atleast_nd(x, ndim=2)
xp_assert_equal(y, xp.asarray([[0, 1]]))

y = atleast_nd(x, ndim=5)
xp_assert_equal(y, xp.asarray([[[[[0, 1]]]]]))

def test_2D(self, xp: ModuleType):
x = xp.asarray([[3.0]])

y = atleast_nd(x, ndim=0)
xp_assert_equal(y, x)

y = atleast_nd(x, ndim=2)
xp_assert_equal(y, x)

y = atleast_nd(x, ndim=3)
xp_assert_equal(y, 3 * xp.ones((1, 1, 1)))

y = atleast_nd(x, ndim=5)
xp_assert_equal(y, 3 * xp.ones((1, 1, 1, 1, 1)))

def test_3D(self, xp: ModuleType):
x = xp.asarray([[[3.0], [2.0]]])

y = atleast_nd(x, ndim=0)
xp_assert_equal(y, x)

y = atleast_nd(x, ndim=2)
xp_assert_equal(y, x)

y = atleast_nd(x, ndim=3)
xp_assert_equal(y, x)

y = atleast_nd(x, ndim=5)
xp_assert_equal(y, xp.asarray([[[[[3.0], [2.0]]]]]))

def test_5D(self, xp: ModuleType):
x = xp.ones((1, 1, 1, 1, 1))

y = atleast_nd(x, ndim=0)
xp_assert_equal(y, x)
@pytest.mark.parametrize(
("x_data", "ndim", "expected_data"),
[
# --- size-1 vector ---
([3.0], 0, [3.0]),
([3.0], 1, [3.0]),
([3.0], 2, [[3.0]]),
([3.0], 3, [[[3.0]]]),
([3.0], 5, [[[[[3.0]]]]]),
# --- size-2 vector ---
([0.0, 1.0], 0, [0.0, 1.0]),
([0.0, 1.0], 1, [0.0, 1.0]),
([0.0, 1.0], 2, [[0.0, 1.0]]),
([0.0, 1.0], 5, [[[[[0.0, 1.0]]]]]),
],
)
def test_1D(
self,
x_data: NestedFloatList,
ndim: int,
expected_data: NestedFloatList,
xp: ModuleType,
):
x = xp.asarray(x_data)
expected = xp.asarray(expected_data)
y = atleast_nd(x, ndim=ndim)
xp_assert_equal(y, expected)

y = atleast_nd(x, ndim=4)
xp_assert_equal(y, x)
@pytest.mark.parametrize(
("x_data", "ndim", "expected_data"),
[
# --- size-1 vector ---
([[3.0]], 0, [[3.0]]),
([[3.0]], 1, [[3.0]]),
([[3.0]], 2, [[3.0]]),
([[3.0]], 3, [[[3.0]]]),
([[3.0]], 5, [[[[[3.0]]]]]),
# --- size-2 vector ---
([[0.0], [1.0]], 0, [[0.0], [1.0]]),
([[0.0, 1.0]], 1, [[0.0, 1.0]]),
([[0.0, 1.0]], 2, [[0.0, 1.0]]),
([[0.0], [1.0]], 3, [[[0.0], [1.0]]]),
([[0.0, 1.0]], 5, [[[[[0.0, 1.0]]]]]),
],
)
def test_2D(
self,
x_data: NestedFloatList,
ndim: int,
expected_data: NestedFloatList,
xp: ModuleType,
):
x = xp.asarray(x_data)
expected = xp.asarray(expected_data)
y = atleast_nd(x, ndim=ndim)
xp_assert_equal(y, expected)

y = atleast_nd(x, ndim=5)
xp_assert_equal(y, x)
@pytest.mark.parametrize(
("x_data", "ndim", "expected_data"),
[
([[[0.0]], [[1.0]]], 0, [[[0.0]], [[1.0]]]),
([[[0.0], [1.0]]], 1, [[[0.0], [1.0]]]),
([[[0.0, 1.0]]], 2, [[[0.0, 1.0]]]),
([[[0.0]], [[1.0]]], 3, [[[0.0]], [[1.0]]]),
([[[0.0], [1.0]]], 5, [[[[[0.0], [1.0]]]]]),
],
)
def test_3D(
self,
x_data: NestedFloatList,
ndim: int,
expected_data: NestedFloatList,
xp: ModuleType,
):
x = xp.asarray(x_data)
expected = xp.asarray(expected_data)
y = atleast_nd(x, ndim=ndim)
xp_assert_equal(y, expected)

y = atleast_nd(x, ndim=6)
xp_assert_equal(y, xp.ones((1, 1, 1, 1, 1, 1)))
@pytest.mark.parametrize(
("x_data", "ndim", "expected_data"),
[
([[[[3.0], [2.0]]]], 0, [[[[3.0], [2.0]]]]),
([[[[3.0, 2.0]]]], 2, [[[[3.0, 2.0]]]]),
([[[[3.0]], [[2.0]]]], 4, [[[[3.0]], [[2.0]]]]),
([[[[3.0]]], [[[2.0]]]], 5, [[[[[3.0]]], [[[2.0]]]]]),
],
)
def test_4D(
self,
x_data: NestedFloatList,
ndim: int,
expected_data: NestedFloatList,
xp: ModuleType,
):
x = xp.asarray(x_data)
expected = xp.asarray(expected_data)
y = atleast_nd(x, ndim=ndim)
xp_assert_equal(y, expected)

y = atleast_nd(x, ndim=9)
xp_assert_equal(y, xp.ones((1, 1, 1, 1, 1, 1, 1, 1, 1)))
@pytest.mark.parametrize(
("x_data", "ndim", "expected_data"),
[
([[[[[3.0]], [[2.0]], [[1.0]]]]], 0, [[[[[3.0]], [[2.0]], [[1.0]]]]]),
([[[[[3.0, 2.0, 6.0]]]]], 2, [[[[[3.0, 2.0, 6.0]]]]]),
(
[[[[[3.0]]], [[[2.0]]], [[[1.0]]]]],
4,
[[[[[3.0]]], [[[2.0]]], [[[1.0]]]]],
),
(
[[[[[3.0]], [[1.0]]], [[[2.0]], [[1.0]]], [[[1.0]], [[1.0]]]]],
6,
[[[[[[3.0]], [[1.0]]], [[[2.0]], [[1.0]]], [[[1.0]], [[1.0]]]]]],
),
(
[[[[[3.0]], [[1.0]]], [[[2.0]], [[1.0]]], [[[1.0]], [[1.0]]]]],
9,
[[[[[[[[[3.0]], [[1.0]]], [[[2.0]], [[1.0]]], [[[1.0]], [[1.0]]]]]]]]],
),
],
)
def test_5D(
self,
x_data: NestedFloatList,
ndim: int,
expected_data: NestedFloatList,
xp: ModuleType,
):
x = xp.asarray(x_data)
expected = xp.asarray(expected_data)
y = atleast_nd(x, ndim=ndim)
xp_assert_equal(y, expected)

def test_device(self, xp: ModuleType, device: Device):
x = xp.asarray([1, 2, 3], device=device)
Expand Down