|
1 | | -from . import ( |
2 | | - approximators, |
3 | | - adapters, |
4 | | - augmentations, |
5 | | - datasets, |
6 | | - diagnostics, |
7 | | - distributions, |
8 | | - experimental, |
9 | | - networks, |
10 | | - simulators, |
11 | | - utils, |
12 | | - workflows, |
13 | | - wrappers, |
14 | | -) |
15 | | - |
16 | | -from .adapters import Adapter |
17 | | -from .approximators import ContinuousApproximator, PointApproximator |
18 | | -from .datasets import OfflineDataset, OnlineDataset, DiskDataset |
19 | | -from .simulators import make_simulator |
20 | | -from .workflows import BasicWorkflow |
| 1 | +# ruff: noqa: E402 |
| 2 | +# disable E402 to allow for setup code before importing any internals (which could import keras) |
21 | 3 |
|
22 | 4 |
|
23 | 5 | def setup(): |
24 | 6 | # perform any necessary setup without polluting the namespace |
| 7 | + import os |
| 8 | + from importlib.util import find_spec |
| 9 | + |
| 10 | + issue_url = "https://github.com/bayesflow-org/bayesflow/issues/new?template=bug_report.md" |
| 11 | + |
| 12 | + if "KERAS_BACKEND" not in os.environ: |
| 13 | + # check for available backends and automatically set the KERAS_BACKEND env variable or raise an error |
| 14 | + class Backend: |
| 15 | + def __init__(self, display_name, package_name, env_name, install_url, priority): |
| 16 | + self.display_name = display_name |
| 17 | + self.package_name = package_name |
| 18 | + self.env_name = env_name |
| 19 | + self.install_url = install_url |
| 20 | + self.priority = priority |
| 21 | + |
| 22 | + backends = [ |
| 23 | + Backend("JAX", "jax", "jax", "https://docs.jax.dev/en/latest/quickstart.html#installation", 0), |
| 24 | + Backend("PyTorch", "torch", "torch", "https://pytorch.org/get-started/locally/", 1), |
| 25 | + Backend("TensorFlow", "tensorflow", "tensorflow", "https://www.tensorflow.org/install", 2), |
| 26 | + ] |
| 27 | + |
| 28 | + found_backends = [] |
| 29 | + for backend in backends: |
| 30 | + if find_spec(backend.package_name) is not None: |
| 31 | + found_backends.append(backend) |
| 32 | + |
| 33 | + if not found_backends: |
| 34 | + message = "No suitable backend found. Please install one of the following:\n" |
| 35 | + for backend in backends: |
| 36 | + message += f"{backend.display_name}\n" |
| 37 | + message += "\n" |
| 38 | + |
| 39 | + message += f"If you continue to see this error, please file a bug report at {issue_url}.\n" |
| 40 | + message += ( |
| 41 | + "You can manually select a backend by setting the KERAS_BACKEND environment variable as shown below:\n" |
| 42 | + ) |
| 43 | + message += "https://keras.io/getting_started/#configuring-your-backend" |
| 44 | + |
| 45 | + raise ImportError(message) |
| 46 | + |
| 47 | + if len(found_backends) > 1: |
| 48 | + import warnings |
| 49 | + |
| 50 | + found_backends.sort(key=lambda b: b.priority) |
| 51 | + chosen_backend = found_backends[0] |
| 52 | + |
| 53 | + warnings.warn( |
| 54 | + f"Multiple Keras-compatible backends detected ({', '.join(b.display_name for b in found_backends)}).\n" |
| 55 | + f"Defaulting to {chosen_backend.display_name}.\n" |
| 56 | + "To override, set the KERAS_BACKEND environment variable before importing bayesflow.\n" |
| 57 | + "See: https://keras.io/getting_started/#configuring-your-backend" |
| 58 | + ) |
| 59 | + else: |
| 60 | + os.environ["KERAS_BACKEND"] = found_backends[0].env_name |
| 61 | + |
25 | 62 | import keras |
26 | 63 | import logging |
27 | 64 |
|
@@ -60,3 +97,24 @@ def setup(): |
60 | 97 | # call and clean up namespace |
61 | 98 | setup() |
62 | 99 | del setup |
| 100 | + |
| 101 | +from . import ( |
| 102 | + approximators, |
| 103 | + adapters, |
| 104 | + augmentations, |
| 105 | + datasets, |
| 106 | + diagnostics, |
| 107 | + distributions, |
| 108 | + experimental, |
| 109 | + networks, |
| 110 | + simulators, |
| 111 | + utils, |
| 112 | + workflows, |
| 113 | + wrappers, |
| 114 | +) |
| 115 | + |
| 116 | +from .adapters import Adapter |
| 117 | +from .approximators import ContinuousApproximator, PointApproximator |
| 118 | +from .datasets import OfflineDataset, OnlineDataset, DiskDataset |
| 119 | +from .simulators import make_simulator |
| 120 | +from .workflows import BasicWorkflow |
0 commit comments