Skip to content

Commit 8088d82

Browse files
committed
ENH: searchsorted: draw x1.dtype
1 parent 73228d1 commit 8088d82

File tree

1 file changed

+4
-3
lines changed

1 file changed

+4
-3
lines changed

array_api_tests/test_searching_functions.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -244,11 +244,12 @@ def test_where(shapes, dtypes, data):
244244
@given(data=st.data())
245245
def test_searchsorted(data):
246246
# TODO: Allow different dtypes for x1 and x2
247+
x1_dtype = data.draw(st.sampled_from(dh.real_dtypes))
247248
_x1 = data.draw(
248-
st.lists(xps.from_dtype(dh.default_float), min_size=1, unique=True),
249+
st.lists(xps.from_dtype(x1_dtype), min_size=1, unique=True),
249250
label="_x1",
250251
)
251-
x1 = xp.asarray(_x1, dtype=dh.default_float)
252+
x1 = xp.asarray(_x1, dtype=x1_dtype)
252253
if data.draw(st.booleans(), label="use sorter?"):
253254
sorter = xp.argsort(x1)
254255
else:
@@ -258,7 +259,7 @@ def test_searchsorted(data):
258259

259260
x2 = data.draw(
260261
st.lists(st.sampled_from(_x1), unique=True, min_size=1).map(
261-
lambda o: xp.asarray(o, dtype=dh.default_float)
262+
lambda o: xp.asarray(o, dtype=x1_dtype)
262263
),
263264
label="x2",
264265
)

0 commit comments

Comments
 (0)