Skip to content

Commit e64813e

Browse files
committed
add xfail according to jax issue
1 parent b9c1503 commit e64813e

File tree

2 files changed

+12
-5
lines changed

2 files changed

+12
-5
lines changed

src/array_api_extra/_delegation.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -599,8 +599,10 @@ def setdiff1d(
599599
if xp is None:
600600
xp = array_namespace(x1, x2)
601601

602-
if is_numpy_namespace(xp) or is_jax_namespace(xp) or is_cupy_namespace(xp):
603-
return xp.setdiff1d(x1, x2, assume_unique=assume_unique)
602+
if is_numpy_namespace(xp) or is_cupy_namespace(xp) or is_jax_namespace(xp):
603+
# https://github.com/microsoft/pyright/issues/10103
604+
x1_, x2_ = asarrays(x1, x2, xp=xp)
605+
return xp.setdiff1d(x1_, x2_, assume_unique=assume_unique)
604606

605607
return _funcs.setdiff1d(x1, x2, assume_unique=assume_unique, xp=xp)
606608

tests/test_funcs.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -34,9 +34,7 @@
3434
)
3535
from array_api_extra._lib._backends import NUMPY_VERSION, Backend
3636
from array_api_extra._lib._testing import xp_assert_close, xp_assert_equal
37-
from array_api_extra._lib._utils._compat import (
38-
device as get_device,
39-
)
37+
from array_api_extra._lib._utils._compat import device as get_device, is_jax_namespace
4038
from array_api_extra._lib._utils._helpers import eager_shape, ndindex
4139
from array_api_extra._lib._utils._typing import Array, Device
4240
from array_api_extra.testing import lazy_xp_function
@@ -1271,6 +1269,10 @@ def test_shapes(
12711269
):
12721270
x1 = xp.zeros(shape1)
12731271
x2 = xp.zeros(shape2)
1272+
1273+
if is_jax_namespace(xp) and assume_unique and shape1!=(1,):
1274+
pytest.xfail(reason="jax#32335 fixed with jax>=0.8.0")
1275+
12741276
actual = setdiff1d(x1, x2, assume_unique=assume_unique)
12751277
xp_assert_equal(actual, xp.empty((0,)))
12761278

@@ -1283,6 +1285,9 @@ def test_python_scalar(self, xp: ModuleType, assume_unique: bool):
12831285
actual = setdiff1d(x1, x2, assume_unique=assume_unique)
12841286
xp_assert_equal(actual, xp.asarray([1, 2], dtype=xp.int16))
12851287

1288+
if is_jax_namespace(xp) and assume_unique:
1289+
pytest.xfail(reason="jax#32335 fixed with jax>=0.8.0")
1290+
12861291
actual = setdiff1d(x2, x1, assume_unique=assume_unique)
12871292
xp_assert_equal(actual, xp.asarray([], dtype=xp.int16))
12881293

0 commit comments

Comments
 (0)