Skip to content

Commit 7bd70ef

Browse files
committed
ENH: searchsorted: allow python scalars for x2
cross-ref data-apis/array-api#982
1 parent 8088d82 commit 7bd70ef

File tree

1 file changed

+45
-0
lines changed

1 file changed

+45
-0
lines changed

array_api_tests/test_searching_functions.py

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -285,3 +285,48 @@ def test_searchsorted(data):
285285
except Exception as exc:
286286
exc.add_note(repro_snippet)
287287
raise
288+
289+
290+
### @pytest.mark.min_version("2025.12")
291+
@given(data=st.data())
292+
def test_searchsorted_with_scalars(data):
293+
# 1. draw x1, sorter and side exactly the same as in test_searchsorted
294+
x1_dtype = data.draw(st.sampled_from(dh.real_dtypes))
295+
_x1 = data.draw(
296+
st.lists(
297+
xps.from_dtype(x1_dtype, allow_nan=False, allow_infinity=False),
298+
min_size=1,
299+
unique=True
300+
),
301+
label="_x1",
302+
)
303+
x1 = xp.asarray(_x1, dtype=x1_dtype)
304+
if data.draw(st.booleans(), label="use sorter?"):
305+
sorter = xp.argsort(x1)
306+
else:
307+
sorter = None
308+
x1 = xp.sort(x1)
309+
310+
kw = data.draw(hh.kwargs(side=st.sampled_from(["left", "right"])))
311+
312+
# 2. draw x2, a real-valued scalar (IOW, an int or a float)
313+
x2 = data.draw(hh.scalars(st.sampled_from([xp.int32, xp.float64]), finite=True))
314+
315+
# 3. testing: similar to test_searchsorted, modulo `out.shape == ()`
316+
repro_snippet = ph.format_snippet(
317+
f"xp.searchsorted({x1!r}, {x2!r}, sorter={sorter!r}, **kw) with {kw = }"
318+
)
319+
try:
320+
out = xp.searchsorted(x1, x2, sorter=sorter, **kw)
321+
322+
ph.assert_dtype(
323+
"searchsorted",
324+
in_dtype=[x1.dtype], #, x2.dtype
325+
out_dtype=out.dtype,
326+
expected=xp.__array_namespace_info__().default_dtypes()["indexing"],
327+
)
328+
# TODO: values testing
329+
ph.assert_shape("searchsorted", out_shape=out.shape, expected=())
330+
except Exception as exc:
331+
exc.add_note(repro_snippet)
332+
raise

0 commit comments

Comments
 (0)