Skip to content

Commit 8afff13

Browse files
LarsKueCopilotstefanradev93
authored
Auto-select backend (#543)
* add automatic backend detection and selection * Fix typo Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * Add priority ordering of backends --------- Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> Co-authored-by: stefanradev93 <stefan.radev93@gmail.com>
1 parent d9e9782 commit 8afff13

File tree

1 file changed

+78
-20
lines changed

1 file changed

+78
-20
lines changed

bayesflow/__init__.py

Lines changed: 78 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,27 +1,64 @@
1-
from . import (
2-
approximators,
3-
adapters,
4-
augmentations,
5-
datasets,
6-
diagnostics,
7-
distributions,
8-
experimental,
9-
networks,
10-
simulators,
11-
utils,
12-
workflows,
13-
wrappers,
14-
)
15-
16-
from .adapters import Adapter
17-
from .approximators import ContinuousApproximator, PointApproximator
18-
from .datasets import OfflineDataset, OnlineDataset, DiskDataset
19-
from .simulators import make_simulator
20-
from .workflows import BasicWorkflow
1+
# ruff: noqa: E402
2+
# disable E402 to allow for setup code before importing any internals (which could import keras)
213

224

235
def setup():
246
# perform any necessary setup without polluting the namespace
7+
import os
8+
from importlib.util import find_spec
9+
10+
issue_url = "https://github.com/bayesflow-org/bayesflow/issues/new?template=bug_report.md"
11+
12+
if "KERAS_BACKEND" not in os.environ:
13+
# check for available backends and automatically set the KERAS_BACKEND env variable or raise an error
14+
class Backend:
15+
def __init__(self, display_name, package_name, env_name, install_url, priority):
16+
self.display_name = display_name
17+
self.package_name = package_name
18+
self.env_name = env_name
19+
self.install_url = install_url
20+
self.priority = priority
21+
22+
backends = [
23+
Backend("JAX", "jax", "jax", "https://docs.jax.dev/en/latest/quickstart.html#installation", 0),
24+
Backend("PyTorch", "torch", "torch", "https://pytorch.org/get-started/locally/", 1),
25+
Backend("TensorFlow", "tensorflow", "tensorflow", "https://www.tensorflow.org/install", 2),
26+
]
27+
28+
found_backends = []
29+
for backend in backends:
30+
if find_spec(backend.package_name) is not None:
31+
found_backends.append(backend)
32+
33+
if not found_backends:
34+
message = "No suitable backend found. Please install one of the following:\n"
35+
for backend in backends:
36+
message += f"{backend.display_name}\n"
37+
message += "\n"
38+
39+
message += f"If you continue to see this error, please file a bug report at {issue_url}.\n"
40+
message += (
41+
"You can manually select a backend by setting the KERAS_BACKEND environment variable as shown below:\n"
42+
)
43+
message += "https://keras.io/getting_started/#configuring-your-backend"
44+
45+
raise ImportError(message)
46+
47+
if len(found_backends) > 1:
48+
import warnings
49+
50+
found_backends.sort(key=lambda b: b.priority)
51+
chosen_backend = found_backends[0]
52+
53+
warnings.warn(
54+
f"Multiple Keras-compatible backends detected ({', '.join(b.display_name for b in found_backends)}).\n"
55+
f"Defaulting to {chosen_backend.display_name}.\n"
56+
"To override, set the KERAS_BACKEND environment variable before importing bayesflow.\n"
57+
"See: https://keras.io/getting_started/#configuring-your-backend"
58+
)
59+
else:
60+
os.environ["KERAS_BACKEND"] = found_backends[0].env_name
61+
2562
import keras
2663
import logging
2764

@@ -60,3 +97,24 @@ def setup():
6097
# call and clean up namespace
6198
setup()
6299
del setup
100+
101+
from . import (
102+
approximators,
103+
adapters,
104+
augmentations,
105+
datasets,
106+
diagnostics,
107+
distributions,
108+
experimental,
109+
networks,
110+
simulators,
111+
utils,
112+
workflows,
113+
wrappers,
114+
)
115+
116+
from .adapters import Adapter
117+
from .approximators import ContinuousApproximator, PointApproximator
118+
from .datasets import OfflineDataset, OnlineDataset, DiskDataset
119+
from .simulators import make_simulator
120+
from .workflows import BasicWorkflow

0 commit comments

Comments
 (0)