Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions src/array_api_extra/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
cov,
expand_dims,
isclose,
isin,
nan_to_num,
one_hot,
pad,
Expand Down Expand Up @@ -39,6 +40,7 @@
"default_dtype",
"expand_dims",
"isclose",
"isin",
"kron",
"lazy_apply",
"nan_to_num",
Expand Down
59 changes: 59 additions & 0 deletions src/array_api_extra/_delegation.py
Original file line number Diff line number Diff line change
Expand Up @@ -836,3 +836,62 @@ def argpartition(
# kth is not small compared to x.size

return _funcs.argpartition(a, kth, axis=axis, xp=xp)


def isin(
a: Array,
b: Array,
/,
*,
assume_unique: bool = False,
invert: bool = False,
kind: str | None = None,
xp: ModuleType | None = None,
) -> Array:
"""
Determine whether each element in `a` is present in `b`.

Return a boolean array of the same shape as `a` that is True for elements
that are in `b` and False otherwise.

Parameters
----------
a : array_like
Input elements.
b : array_like
The elements against which to test each element of `a`.
assume_unique : bool, optional
If True, the input arrays are both assumed to be unique which can speed
up the calculation. Default: False.
invert : bool, optional
If True, the values in the returned array are inverted. Default: False.
kind : str | None, optional
The algorithm or method to use. This will not affect the final result,
but will affect the speed and memory use.
For Numpy the options are {None, "sort", "table"}.
For Jax the mapped parameter is instead `method` and the options are
{"compare_all", "binary_search", "sort", and "auto" (default)}
For Cupy, Dask, Torch and the default case this parameter is not present and
thus ignored. Default: None.
xp : array_namespace, optional
The standard-compatible namespace for `a` and `b`. Default: infer.

Returns
-------
array
An array having the same shape as that of `a` that is True for elements
that are in `b` and False otherwise.
"""
if xp is None:
xp = array_namespace(a, b)

if is_numpy_namespace(xp):
return xp.isin(a, b, assume_unique=assume_unique, invert=invert, kind=kind)
if is_jax_namespace(xp):
if kind is None:
kind = "auto"
return xp.isin(a, b, assume_unique=assume_unique, invert=invert, method=kind)
if is_cupy_namespace(xp) or is_torch_namespace(xp) or is_dask_namespace(xp):
return xp.isin(a, b, assume_unique=assume_unique, invert=invert)

return _funcs.isin(a, b, assume_unique=assume_unique, invert=invert, xp=xp)
22 changes: 22 additions & 0 deletions src/array_api_extra/_lib/_funcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -801,3 +801,25 @@ def argpartition( # numpydoc ignore=PR01,RT01
) -> Array:
"""See docstring in `array_api_extra._delegation.py`."""
return xp.argsort(x, axis=axis, stable=False)


def isin( # numpydoc ignore=PR01,RT01
a: Array,
b: Array,
/,
*,
assume_unique: bool = False,
invert: bool = False,
xp: ModuleType | None = None,
) -> Array:
"""See docstring in `array_api_extra._delegation.py`."""
if xp is None:
xp = array_namespace(a, b)

original_a_shape = a.shape
a = xp.reshape(a, (-1,))
b = xp.reshape(b, (-1,))
return xp.reshape(
_helpers.in1d(a, b, assume_unique=assume_unique, invert=invert, xp=xp),
original_a_shape,
)
13 changes: 12 additions & 1 deletion tests/test_funcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
default_dtype,
expand_dims,
isclose,
isin,
kron,
nan_to_num,
nunique,
Expand Down Expand Up @@ -888,7 +889,7 @@ def test_device(self, xp: ModuleType, device: Device, equal_nan: bool):
b = xp.asarray([1e-9, 1e-4, xp.nan], device=device)
res = isclose(a, b, equal_nan=equal_nan)
assert get_device(res) == device

def test_array_on_device_with_scalar(self, xp: ModuleType, device: Device):
a = xp.asarray([0.01, 0.5, 0.8, 0.9, 1.00001], device=device)
b = 1
Expand Down Expand Up @@ -1476,3 +1477,13 @@ def test_nd(self, xp: ModuleType, ndim: int):
@override
def test_input_validation(self, xp: ModuleType):
self._test_input_validation(xp)


@pytest.mark.xfail_xp_backend(Backend.SPARSE, reason="no unique_inverse")
class TestIsIn:
def test_simple(self, xp: ModuleType):
a = xp.asarray([[0, 2], [4, 6]])
b = xp.asarray([1, 2, 3, 4])
expected = xp.asarray([[False, True], [True, False]])
res = isin(a, b)
xp_assert_equal(res, expected)
Loading