|
1 | 1 | import importlib |
2 | | -import itertools |
3 | 2 | import operator |
4 | 3 | import os |
5 | 4 |
|
@@ -28,41 +27,48 @@ def elemwise_args(request, rng, max_size): |
28 | 27 | return s1_sps, s2_sps |
29 | 28 |
|
30 | 29 |
|
31 | | -def get_elemwise_id(param): |
32 | | - f, backend = param |
33 | | - return f"{f=}-{backend=}" |
| 30 | +@pytest.fixture(params=[operator.add, operator.mul, operator.gt]) |
| 31 | +def elemwise_function(request): |
| 32 | + return request.param |
34 | 33 |
|
35 | 34 |
|
36 | | -@pytest.fixture( |
37 | | - params=itertools.product([operator.add, operator.mul, operator.gt], ["SciPy", "Numba", "Finch"]), |
38 | | - scope="function", |
39 | | - ids=get_elemwise_id, |
40 | | -) |
41 | | -def backend(request): |
42 | | - f, backend = request.param |
43 | | - os.environ[sparse._ENV_VAR_NAME] = backend |
| 35 | +@pytest.fixture(params=["SciPy", "Numba", "Finch"]) |
| 36 | +def backend_name(request): |
| 37 | + return request.param |
| 38 | + |
| 39 | + |
| 40 | +@pytest.fixture |
| 41 | +def backend_setup(backend_name): |
| 42 | + os.environ[sparse._ENV_VAR_NAME] = backend_name |
44 | 43 | importlib.reload(sparse) |
45 | | - yield f, sparse, backend |
| 44 | + yield sparse, backend_name |
46 | 45 | del os.environ[sparse._ENV_VAR_NAME] |
47 | 46 | importlib.reload(sparse) |
48 | 47 |
|
49 | 48 |
|
50 | | -def test_elemwise(benchmark, backend, elemwise_args): |
| 49 | +@pytest.fixture |
| 50 | +def sparse_arrays(elemwise_args, backend_setup): |
51 | 51 | s1_sps, s2_sps = elemwise_args |
52 | | - f, sparse, backend = backend |
| 52 | + sparse, backend_name = backend_setup |
53 | 53 |
|
54 | | - if backend == "SciPy": |
| 54 | + if backend_name == "SciPy": |
55 | 55 | s1 = s1_sps |
56 | 56 | s2 = s2_sps |
57 | | - elif backend == "Numba": |
| 57 | + elif backend_name == "Numba": |
58 | 58 | s1 = sparse.asarray(s1_sps) |
59 | 59 | s2 = sparse.asarray(s2_sps) |
60 | | - elif backend == "Finch": |
| 60 | + elif backend_name == "Finch": |
61 | 61 | s1 = sparse.asarray(s1_sps.asformat("csc"), format="csc") |
62 | 62 | s2 = sparse.asarray(s2_sps.asformat("csc"), format="csc") |
63 | 63 |
|
64 | | - f(s1, s2) |
| 64 | + return s1, s2 |
| 65 | + |
| 66 | + |
| 67 | +def test_elemwise(benchmark, elemwise_function, sparse_arrays): |
| 68 | + s1, s2 = sparse_arrays |
| 69 | + |
| 70 | + elemwise_function(s1, s2) |
65 | 71 |
|
66 | 72 | @benchmark |
67 | 73 | def bench(): |
68 | | - f(s1, s2) |
| 74 | + elemwise_function(s1, s2) |
0 commit comments