|
| 1 | +import os |
1 | 2 | from functools import wraps |
2 | | -from os import getenv |
| 3 | +from importlib import import_module |
3 | 4 |
|
4 | 5 | from hypothesis import strategies as st |
5 | 6 | from hypothesis.extra import array_api |
6 | 7 |
|
7 | 8 | from . import _version |
8 | | -from ._array_module import mod as _xp |
9 | 9 |
|
10 | | -__all__ = ["api_version", "xps"] |
| 10 | +__all__ = ["xp", "api_version", "xps"] |
| 11 | + |
| 12 | + |
| 13 | +# You can comment the following out and instead import the specific array module |
| 14 | +# you want to test, e.g. `import numpy.array_api as xp`. |
| 15 | +if "ARRAY_API_TESTS_MODULE" in os.environ: |
| 16 | + xp_name = os.environ["ARRAY_API_TESTS_MODULE"] |
| 17 | + _module, _sub = xp_name, None |
| 18 | + if "." in xp_name: |
| 19 | + _module, _sub = xp_name.split(".", 1) |
| 20 | + xp = import_module(_module) |
| 21 | + if _sub: |
| 22 | + try: |
| 23 | + xp = getattr(xp, _sub) |
| 24 | + except AttributeError: |
| 25 | + # _sub may be a submodule that needs to be imported. WE can't |
| 26 | + # do this in every case because some array modules are not |
| 27 | + # submodules that can be imported (like mxnet.nd). |
| 28 | + xp = import_module(xp_name) |
| 29 | +else: |
| 30 | + raise RuntimeError( |
| 31 | + "No array module specified - either edit __init__.py or set the " |
| 32 | + "ARRAY_API_TESTS_MODULE environment variable." |
| 33 | + ) |
11 | 34 |
|
12 | 35 |
|
13 | 36 | # We monkey patch floats() to always disable subnormals as they are out-of-scope |
@@ -43,9 +66,9 @@ def _from_dtype(*a, **kw): |
43 | 66 | pass |
44 | 67 |
|
45 | 68 |
|
46 | | -api_version = getenv( |
47 | | - "ARRAY_API_TESTS_VERSION", getattr(_xp, "__array_api_version__", "2021.12") |
| 69 | +api_version = os.getenv( |
| 70 | + "ARRAY_API_TESTS_VERSION", getattr(xp, "__array_api_version__", "2021.12") |
48 | 71 | ) |
49 | | -xps = array_api.make_strategies_namespace(_xp, api_version=api_version) |
| 72 | +xps = array_api.make_strategies_namespace(xp, api_version=api_version) |
50 | 73 |
|
51 | 74 | __version__ = _version.get_versions()["version"] |
0 commit comments