Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/array_api_extra/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
one_hot,
pad,
partition,
setdiff1d,
sinc,
)
from ._lib._at import at
Expand All @@ -21,7 +22,6 @@
default_dtype,
kron,
nunique,
setdiff1d,
)
from ._lib._lazy import lazy_apply

Expand Down
53 changes: 53 additions & 0 deletions src/array_api_extra/_delegation.py
Original file line number Diff line number Diff line change
Expand Up @@ -553,6 +553,59 @@ def pad(
return _funcs.pad(x, pad_width, constant_values=constant_values, xp=xp)


def setdiff1d(
x1: Array | complex,
x2: Array | complex,
/,
*,
assume_unique: bool = False,
xp: ModuleType | None = None,
) -> Array:
"""
Find the set difference of two arrays.

Return the unique values in `x1` that are not in `x2`.

Parameters
----------
x1 : array | int | float | complex | bool
Input array.
x2 : array
Input comparison array.
assume_unique : bool
If ``True``, the input arrays are both assumed to be unique, which
can speed up the calculation. Default is ``False``.
xp : array_namespace, optional
The standard-compatible namespace for `x1` and `x2`. Default: infer.

Returns
-------
array
1D array of values in `x1` that are not in `x2`. The result
is sorted when `assume_unique` is ``False``, but otherwise only sorted
if the input is sorted.

Examples
--------
>>> import array_api_strict as xp
>>> import array_api_extra as xpx

>>> x1 = xp.asarray([1, 2, 3, 2, 4, 1])
>>> x2 = xp.asarray([3, 4, 5, 6])
>>> xpx.setdiff1d(x1, x2, xp=xp)
Array([1, 2], dtype=array_api_strict.int64)
"""

if xp is None:
xp = array_namespace(x1, x2)

if is_numpy_namespace(xp) or is_cupy_namespace(xp) or is_jax_namespace(xp):
x1, x2 = asarrays(x1, x2, xp=xp)
return xp.setdiff1d(x1, x2, assume_unique=assume_unique)

return _funcs.setdiff1d(x1, x2, assume_unique=assume_unique, xp=xp)


def sinc(x: Array, /, *, xp: ModuleType | None = None) -> Array:
r"""
Return the normalized sinc function.
Expand Down
40 changes: 3 additions & 37 deletions src/array_api_extra/_lib/_funcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -715,44 +715,10 @@ def setdiff1d(
/,
*,
assume_unique: bool = False,
xp: ModuleType | None = None,
) -> Array:
"""
Find the set difference of two arrays.

Return the unique values in `x1` that are not in `x2`.

Parameters
----------
x1 : array | int | float | complex | bool
Input array.
x2 : array
Input comparison array.
assume_unique : bool
If ``True``, the input arrays are both assumed to be unique, which
can speed up the calculation. Default is ``False``.
xp : array_namespace, optional
The standard-compatible namespace for `x1` and `x2`. Default: infer.

Returns
-------
array
1D array of values in `x1` that are not in `x2`. The result
is sorted when `assume_unique` is ``False``, but otherwise only sorted
if the input is sorted.

Examples
--------
>>> import array_api_strict as xp
>>> import array_api_extra as xpx
xp: ModuleType,
) -> Array: # numpydoc ignore=PR01,RT01
"""See docstring in `array_api_extra._delegation.py`."""

>>> x1 = xp.asarray([1, 2, 3, 2, 4, 1])
>>> x2 = xp.asarray([3, 4, 5, 6])
>>> xpx.setdiff1d(x1, x2, xp=xp)
Array([1, 2], dtype=array_api_strict.int64)
"""
if xp is None:
xp = array_namespace(x1, x2)
# https://github.com/microsoft/pyright/issues/10103
x1_, x2_ = asarrays(x1, x2, xp=xp)

Expand Down
12 changes: 9 additions & 3 deletions tests/test_funcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,9 +34,8 @@
)
from array_api_extra._lib._backends import NUMPY_VERSION, Backend
from array_api_extra._lib._testing import xp_assert_close, xp_assert_equal
from array_api_extra._lib._utils._compat import (
device as get_device,
)
from array_api_extra._lib._utils._compat import device as get_device
from array_api_extra._lib._utils._compat import is_jax_namespace
from array_api_extra._lib._utils._helpers import eager_shape, ndindex
from array_api_extra._lib._utils._typing import Array, Device
from array_api_extra.testing import lazy_xp_function
Expand Down Expand Up @@ -1271,6 +1270,10 @@ def test_shapes(
):
x1 = xp.zeros(shape1)
x2 = xp.zeros(shape2)

if is_jax_namespace(xp) and assume_unique and shape1 != (1,):
pytest.xfail(reason="jax#32335 fixed with jax>=0.8.0")

actual = setdiff1d(x1, x2, assume_unique=assume_unique)
xp_assert_equal(actual, xp.empty((0,)))

Expand All @@ -1283,6 +1286,9 @@ def test_python_scalar(self, xp: ModuleType, assume_unique: bool):
actual = setdiff1d(x1, x2, assume_unique=assume_unique)
xp_assert_equal(actual, xp.asarray([1, 2], dtype=xp.int16))

if is_jax_namespace(xp) and assume_unique:
pytest.xfail(reason="jax#32335 fixed with jax>=0.8.0")

actual = setdiff1d(x2, x1, assume_unique=assume_unique)
xp_assert_equal(actual, xp.asarray([], dtype=xp.int16))

Expand Down