|
23 | 23 | from warnings import warn |
24 | 24 |
|
25 | 25 | import pytest |
26 | | -from hypothesis import assume, given, note |
| 26 | +from hypothesis import given, note, settings |
27 | 27 | from hypothesis import strategies as st |
28 | 28 |
|
29 | 29 | from array_api_tests.typing import Array, DataType |
30 | 30 |
|
31 | 31 | from . import dtype_helpers as dh |
32 | 32 | from . import hypothesis_helpers as hh |
33 | 33 | from . import pytest_helpers as ph |
34 | | -from . import shape_helpers as sh |
35 | 34 | from . import xp, xps |
36 | 35 | from .stubs import category_to_funcs |
37 | 36 |
|
@@ -1210,143 +1209,57 @@ def parse_binary_case_block(case_block: str) -> List[BinaryCase]: |
1210 | 1209 | assert len(iop_params) != 0 |
1211 | 1210 |
|
1212 | 1211 |
|
1213 | | -@pytest.mark.unvectorized |
1214 | 1212 | @pytest.mark.parametrize("func_name, func, case", unary_params) |
1215 | | -@given( |
1216 | | - x=hh.arrays(dtype=xps.floating_dtypes(), shape=hh.shapes(min_side=1)), |
1217 | | - data=st.data(), |
1218 | | -) |
1219 | | -def test_unary(func_name, func, case, x, data): |
1220 | | - set_idx = data.draw( |
1221 | | - xps.indices(x.shape, max_dims=0, allow_ellipsis=False), label="set idx" |
| 1213 | +def test_unary(func_name, func, case): |
| 1214 | + in_value = case.cond_from_dtype(xp.float64).example() |
| 1215 | + x = xp.asarray(in_value, dtype=xp.float64) |
| 1216 | + out = func(x) |
| 1217 | + out_value = float(out) |
| 1218 | + assert case.check_result(in_value, out_value), ( |
| 1219 | + f"out={out_value}, but should be {case.result_expr} [{func_name}()]\n" |
1222 | 1220 | ) |
1223 | | - set_value = data.draw(case.cond_from_dtype(x.dtype), label="set value") |
1224 | | - x[set_idx] = set_value |
1225 | | - note(f"{x=}") |
1226 | | - |
1227 | | - res = func(x) |
1228 | | - |
1229 | | - good_example = False |
1230 | | - for idx in sh.ndindex(res.shape): |
1231 | | - in_ = float(x[idx]) |
1232 | | - if case.cond(in_): |
1233 | | - good_example = True |
1234 | | - out = float(res[idx]) |
1235 | | - f_in = f"{sh.fmt_idx('x', idx)}={in_}" |
1236 | | - f_out = f"{sh.fmt_idx('out', idx)}={out}" |
1237 | | - assert case.check_result(in_, out), ( |
1238 | | - f"{f_out}, but should be {case.result_expr} [{func_name}()]\n" |
1239 | | - f"condition: {case.cond_expr}\n" |
1240 | | - f"{f_in}" |
1241 | | - ) |
1242 | | - break |
1243 | | - assume(good_example) |
1244 | | - |
1245 | | - |
1246 | | -x1_strat, x2_strat = hh.two_mutual_arrays( |
1247 | | - dtypes=dh.real_float_dtypes, |
1248 | | - two_shapes=hh.mutually_broadcastable_shapes(2, min_side=1), |
1249 | | -) |
1250 | 1221 |
|
1251 | 1222 |
|
1252 | | -@pytest.mark.unvectorized |
1253 | 1223 | @pytest.mark.parametrize("func_name, func, case", binary_params) |
1254 | | -@given(x1=x1_strat, x2=x2_strat, data=st.data()) |
1255 | | -def test_binary(func_name, func, case, x1, x2, data): |
1256 | | - result_shape = sh.broadcast_shapes(x1.shape, x2.shape) |
1257 | | - all_indices = list(sh.iter_indices(x1.shape, x2.shape, result_shape)) |
1258 | | - |
1259 | | - indices_strat = st.shared(st.sampled_from(all_indices)) |
1260 | | - set_x1_idx = data.draw(indices_strat.map(lambda t: t[0]), label="set x1 idx") |
1261 | | - set_x1_value = data.draw(case.x1_cond_from_dtype(x1.dtype), label="set x1 value") |
1262 | | - x1[set_x1_idx] = set_x1_value |
1263 | | - note(f"{x1=}") |
1264 | | - set_x2_idx = data.draw(indices_strat.map(lambda t: t[1]), label="set x2 idx") |
1265 | | - set_x2_value = data.draw(case.x2_cond_from_dtype(x2.dtype), label="set x2 value") |
1266 | | - x2[set_x2_idx] = set_x2_value |
1267 | | - note(f"{x2=}") |
1268 | | - |
1269 | | - res = func(x1, x2) |
1270 | | - # sanity check |
1271 | | - ph.assert_result_shape( |
1272 | | - func_name, |
1273 | | - in_shapes=[x1.shape, x2.shape], |
1274 | | - out_shape=res.shape, |
1275 | | - expected=result_shape, |
| 1224 | +@settings(max_examples=1) |
| 1225 | +@given(data=st.data()) |
| 1226 | +def test_binary(func_name, func, case, data): |
| 1227 | + # We don't use example() like in test_unary because the same internal shared |
| 1228 | + # strategies used in both x1's and x2's don't "sync" with example() draws. |
| 1229 | + x1_value = data.draw(case.x1_cond_from_dtype(xp.float64), label="x1_value") |
| 1230 | + x2_value = data.draw(case.x2_cond_from_dtype(xp.float64), label="x2_value") |
| 1231 | + x1 = xp.asarray(x1_value, dtype=xp.float64) |
| 1232 | + x2 = xp.asarray(x2_value, dtype=xp.float64) |
| 1233 | + |
| 1234 | + out = func(x1, x2) |
| 1235 | + out_value = float(out) |
| 1236 | + |
| 1237 | + assert case.check_result(x1_value, x2_value, out_value), ( |
| 1238 | + f"out={out_value}, but should be {case.result_expr} [{func_name}()]\n" |
| 1239 | + f"condition: {case}\n" |
| 1240 | + f"x1={x1_value}, x2={x2_value}" |
1276 | 1241 | ) |
1277 | 1242 |
|
1278 | | - good_example = False |
1279 | | - for l_idx, r_idx, o_idx in all_indices: |
1280 | | - l = float(x1[l_idx]) |
1281 | | - r = float(x2[r_idx]) |
1282 | | - if case.cond(l, r): |
1283 | | - good_example = True |
1284 | | - o = float(res[o_idx]) |
1285 | | - f_left = f"{sh.fmt_idx('x1', l_idx)}={l}" |
1286 | | - f_right = f"{sh.fmt_idx('x2', r_idx)}={r}" |
1287 | | - f_out = f"{sh.fmt_idx('out', o_idx)}={o}" |
1288 | | - assert case.check_result(l, r, o), ( |
1289 | | - f"{f_out}, but should be {case.result_expr} [{func_name}()]\n" |
1290 | | - f"condition: {case}\n" |
1291 | | - f"{f_left}, {f_right}" |
1292 | | - ) |
1293 | | - break |
1294 | | - assume(good_example) |
1295 | 1243 |
|
1296 | 1244 |
|
1297 | | -@pytest.mark.unvectorized |
1298 | 1245 | @pytest.mark.parametrize("iop_name, iop, case", iop_params) |
1299 | | -@given( |
1300 | | - oneway_dtypes=hh.oneway_promotable_dtypes(dh.real_float_dtypes), |
1301 | | - oneway_shapes=hh.oneway_broadcastable_shapes(), |
1302 | | - data=st.data(), |
1303 | | -) |
1304 | | -def test_iop(iop_name, iop, case, oneway_dtypes, oneway_shapes, data): |
1305 | | - x1 = data.draw( |
1306 | | - hh.arrays(dtype=oneway_dtypes.result_dtype, shape=oneway_shapes.result_shape), |
1307 | | - label="x1", |
| 1246 | +@settings(max_examples=1) |
| 1247 | +@given(data=st.data()) |
| 1248 | +def test_iop(iop_name, iop, case, data): |
| 1249 | + # See test_binary comment |
| 1250 | + x1_value = data.draw(case.x1_cond_from_dtype(xp.float64), label="x1_value") |
| 1251 | + x2_value = data.draw(case.x2_cond_from_dtype(xp.float64), label="x2_value") |
| 1252 | + x1 = xp.asarray(x1_value, dtype=xp.float64) |
| 1253 | + x2 = xp.asarray(x2_value, dtype=xp.float64) |
| 1254 | + |
| 1255 | + res = iop(x1, x2) |
| 1256 | + res_value = float(res) |
| 1257 | + |
| 1258 | + assert case.check_result(x1_value, x2_value, res_value), ( |
| 1259 | + f"x1={res}, but should be {case.result_expr} [{func_name}()]\n" |
| 1260 | + f"condition: {case}\n" |
| 1261 | + f"x1={x1_value}, x2={x2_value}" |
1308 | 1262 | ) |
1309 | | - x2 = data.draw( |
1310 | | - hh.arrays(dtype=oneway_dtypes.input_dtype, shape=oneway_shapes.input_shape), |
1311 | | - label="x2", |
1312 | | - ) |
1313 | | - |
1314 | | - all_indices = list(sh.iter_indices(x1.shape, x2.shape, x1.shape)) |
1315 | | - |
1316 | | - indices_strat = st.shared(st.sampled_from(all_indices)) |
1317 | | - set_x1_idx = data.draw(indices_strat.map(lambda t: t[0]), label="set x1 idx") |
1318 | | - set_x1_value = data.draw(case.x1_cond_from_dtype(x1.dtype), label="set x1 value") |
1319 | | - x1[set_x1_idx] = set_x1_value |
1320 | | - note(f"{x1=}") |
1321 | | - set_x2_idx = data.draw(indices_strat.map(lambda t: t[1]), label="set x2 idx") |
1322 | | - set_x2_value = data.draw(case.x2_cond_from_dtype(x2.dtype), label="set x2 value") |
1323 | | - x2[set_x2_idx] = set_x2_value |
1324 | | - note(f"{x2=}") |
1325 | | - |
1326 | | - res = xp.asarray(x1, copy=True) |
1327 | | - res = iop(res, x2) |
1328 | | - # sanity check |
1329 | | - ph.assert_result_shape( |
1330 | | - iop_name, in_shapes=[x1.shape, x2.shape], out_shape=res.shape |
1331 | | - ) |
1332 | | - |
1333 | | - good_example = False |
1334 | | - for l_idx, r_idx, o_idx in all_indices: |
1335 | | - l = float(x1[l_idx]) |
1336 | | - r = float(x2[r_idx]) |
1337 | | - if case.cond(l, r): |
1338 | | - good_example = True |
1339 | | - o = float(res[o_idx]) |
1340 | | - f_left = f"{sh.fmt_idx('x1', l_idx)}={l}" |
1341 | | - f_right = f"{sh.fmt_idx('x2', r_idx)}={r}" |
1342 | | - f_out = f"{sh.fmt_idx('out', o_idx)}={o}" |
1343 | | - assert case.check_result(l, r, o), ( |
1344 | | - f"{f_out}, but should be {case.result_expr} [{iop_name}()]\n" |
1345 | | - f"condition: {case}\n" |
1346 | | - f"{f_left}, {f_right}" |
1347 | | - ) |
1348 | | - break |
1349 | | - assume(good_example) |
1350 | 1263 |
|
1351 | 1264 |
|
1352 | 1265 | @pytest.mark.parametrize( |
|
0 commit comments