Skip to content

Commit b9c1503

Browse files
committed
ENH: setdiff1d delegate function
1 parent ebe9a5b commit b9c1503

File tree

3 files changed

+56
-38
lines changed

3 files changed

+56
-38
lines changed

src/array_api_extra/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
one_hot,
1212
pad,
1313
partition,
14+
setdiff1d,
1415
sinc,
1516
)
1617
from ._lib._at import at
@@ -21,7 +22,6 @@
2122
default_dtype,
2223
kron,
2324
nunique,
24-
setdiff1d,
2525
)
2626
from ._lib._lazy import lazy_apply
2727

src/array_api_extra/_delegation.py

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -553,6 +553,58 @@ def pad(
553553
return _funcs.pad(x, pad_width, constant_values=constant_values, xp=xp)
554554

555555

556+
def setdiff1d(
557+
x1: Array | complex,
558+
x2: Array | complex,
559+
/,
560+
*,
561+
assume_unique: bool = False,
562+
xp: ModuleType | None = None,
563+
) -> Array:
564+
"""
565+
Find the set difference of two arrays.
566+
567+
Return the unique values in `x1` that are not in `x2`.
568+
569+
Parameters
570+
----------
571+
x1 : array | int | float | complex | bool
572+
Input array.
573+
x2 : array
574+
Input comparison array.
575+
assume_unique : bool
576+
If ``True``, the input arrays are both assumed to be unique, which
577+
can speed up the calculation. Default is ``False``.
578+
xp : array_namespace, optional
579+
The standard-compatible namespace for `x1` and `x2`. Default: infer.
580+
581+
Returns
582+
-------
583+
array
584+
1D array of values in `x1` that are not in `x2`. The result
585+
is sorted when `assume_unique` is ``False``, but otherwise only sorted
586+
if the input is sorted.
587+
588+
Examples
589+
--------
590+
>>> import array_api_strict as xp
591+
>>> import array_api_extra as xpx
592+
593+
>>> x1 = xp.asarray([1, 2, 3, 2, 4, 1])
594+
>>> x2 = xp.asarray([3, 4, 5, 6])
595+
>>> xpx.setdiff1d(x1, x2, xp=xp)
596+
Array([1, 2], dtype=array_api_strict.int64)
597+
"""
598+
599+
if xp is None:
600+
xp = array_namespace(x1, x2)
601+
602+
if is_numpy_namespace(xp) or is_jax_namespace(xp) or is_cupy_namespace(xp):
603+
return xp.setdiff1d(x1, x2, assume_unique=assume_unique)
604+
605+
return _funcs.setdiff1d(x1, x2, assume_unique=assume_unique, xp=xp)
606+
607+
556608
def sinc(x: Array, /, *, xp: ModuleType | None = None) -> Array:
557609
r"""
558610
Return the normalized sinc function.

src/array_api_extra/_lib/_funcs.py

Lines changed: 3 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -715,44 +715,10 @@ def setdiff1d(
715715
/,
716716
*,
717717
assume_unique: bool = False,
718-
xp: ModuleType | None = None,
719-
) -> Array:
720-
"""
721-
Find the set difference of two arrays.
722-
723-
Return the unique values in `x1` that are not in `x2`.
724-
725-
Parameters
726-
----------
727-
x1 : array | int | float | complex | bool
728-
Input array.
729-
x2 : array
730-
Input comparison array.
731-
assume_unique : bool
732-
If ``True``, the input arrays are both assumed to be unique, which
733-
can speed up the calculation. Default is ``False``.
734-
xp : array_namespace, optional
735-
The standard-compatible namespace for `x1` and `x2`. Default: infer.
736-
737-
Returns
738-
-------
739-
array
740-
1D array of values in `x1` that are not in `x2`. The result
741-
is sorted when `assume_unique` is ``False``, but otherwise only sorted
742-
if the input is sorted.
743-
744-
Examples
745-
--------
746-
>>> import array_api_strict as xp
747-
>>> import array_api_extra as xpx
718+
xp: ModuleType,
719+
) -> Array: # numpydoc ignore=PR01,RT01
720+
"""See docstring in `array_api_extra._delegation.py`."""
748721

749-
>>> x1 = xp.asarray([1, 2, 3, 2, 4, 1])
750-
>>> x2 = xp.asarray([3, 4, 5, 6])
751-
>>> xpx.setdiff1d(x1, x2, xp=xp)
752-
Array([1, 2], dtype=array_api_strict.int64)
753-
"""
754-
if xp is None:
755-
xp = array_namespace(x1, x2)
756722
# https://github.com/microsoft/pyright/issues/10103
757723
x1_, x2_ = asarrays(x1, x2, xp=xp)
758724

0 commit comments

Comments
 (0)