
Commit 5d3594c

Merge branch 'dev' of https://github.com/bayesflow-org/bayesflow into dev [skip ci]
2 parents 0ea79d7 + 706e3fd

44 files changed, +1182 -1058 lines

README.md

Lines changed: 2 additions & 1 deletion
@@ -135,7 +135,8 @@ Many examples from [Bayesian Cognitive Modeling: A Practical Course](https://bay
 6. [Bayesian experimental design](examples/Bayesian_Experimental_Design.ipynb)
 7. [Simple model comparison example](examples/One_Sample_TTest.ipynb)
 8. [Likelihood estimation](examples/Likelihood_Estimation.ipynb)
-9. [Moving from BayesFlow v1.1 to v2.0](examples/From_BayesFlow_1.1_to_2.0.ipynb)
+9. [Multimodal data](examples/Multimodal_Data.ipynb)
+10. [Moving from BayesFlow v1.1 to v2.0](examples/From_BayesFlow_1.1_to_2.0.ipynb)

 More tutorials are always welcome! Please consider making a pull request if you have a cool application that you want to contribute.

bayesflow/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -1,6 +1,7 @@
 from . import (
     approximators,
     adapters,
+    augmentations,
     datasets,
     diagnostics,
     distributions,
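
With this change, the new augmentations submodule is imported at package level and becomes reachable from the top-level namespace. A minimal sanity check (the submodule's contents are not shown in this diff, so only its availability is illustrated):

import bayesflow as bf

# The diff only shows that the submodule is now part of the top-level package.
print(bf.augmentations)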

bayesflow/adapters/adapter.py

Lines changed: 22 additions & 58 deletions
@@ -1,4 +1,4 @@
-from collections.abc import Callable, MutableSequence, Sequence, Mapping
+from collections.abc import Callable, MutableSequence, Sequence

 import numpy as np

@@ -18,7 +18,6 @@
     Keep,
     Log,
     MapTransform,
-    NNPE,
     NumpyTransform,
     OneHot,
     Rename,
@@ -87,16 +86,14 @@ def get_config(self) -> dict:
         return serialize(config)

     def forward(
-        self, data: dict[str, any], *, stage: str = "inference", log_det_jac: bool = False, **kwargs
+        self, data: dict[str, any], *, log_det_jac: bool = False, **kwargs
     ) -> dict[str, np.ndarray] | tuple[dict[str, np.ndarray], dict[str, np.ndarray]]:
         """Apply the transforms in the forward direction.

         Parameters
         ----------
-        data : dict
+        data : dict[str, any]
             The data to be transformed.
-        stage : str, one of ["training", "validation", "inference"]
-            The stage the function is called in.
         log_det_jac: bool, optional
             Whether to return the log determinant of the Jacobian of the transforms.
         **kwargs : dict
@@ -110,28 +107,26 @@ def forward(
         data = data.copy()
         if not log_det_jac:
             for transform in self.transforms:
-                data = transform(data, stage=stage, **kwargs)
+                data = transform(data, **kwargs)
             return data

         log_det_jac = {}
         for transform in self.transforms:
-            transformed_data = transform(data, stage=stage, **kwargs)
+            transformed_data = transform(data, **kwargs)
             log_det_jac = transform.log_det_jac(data, log_det_jac, **kwargs)
             data = transformed_data

         return data, log_det_jac

     def inverse(
-        self, data: dict[str, np.ndarray], *, stage: str = "inference", log_det_jac: bool = False, **kwargs
+        self, data: dict[str, any], *, log_det_jac: bool = False, **kwargs
     ) -> dict[str, np.ndarray] | tuple[dict[str, np.ndarray], dict[str, np.ndarray]]:
         """Apply the transforms in the inverse direction.

         Parameters
         ----------
-        data : dict
+        data : dict[str, any]
             The data to be transformed.
-        stage : str, one of ["training", "validation", "inference"]
-            The stage the function is called in.
         log_det_jac: bool, optional
             Whether to return the log determinant of the Jacobian of the transforms.
         **kwargs : dict
@@ -145,18 +140,18 @@ def inverse(
         data = data.copy()
         if not log_det_jac:
             for transform in reversed(self.transforms):
-                data = transform(data, stage=stage, inverse=True, **kwargs)
+                data = transform(data, inverse=True, **kwargs)
             return data

         log_det_jac = {}
         for transform in reversed(self.transforms):
-            data = transform(data, stage=stage, inverse=True, **kwargs)
+            data = transform(data, inverse=True, **kwargs)
             log_det_jac = transform.log_det_jac(data, log_det_jac, inverse=True, **kwargs)

         return data, log_det_jac

     def __call__(
-        self, data: Mapping[str, any], *, inverse: bool = False, stage="inference", **kwargs
+        self, data: dict[str, any], *, inverse: bool = False, **kwargs
     ) -> dict[str, np.ndarray] | tuple[dict[str, np.ndarray], dict[str, np.ndarray]]:
         """Apply the transforms in the given direction.

@@ -166,8 +161,6 @@ def __call__(
             The data to be transformed.
         inverse : bool, optional
             If False, apply the forward transform, else apply the inverse transform (default False).
-        stage : str, one of ["training", "validation", "inference"]
-            The stage the function is called in.
         **kwargs
             Additional keyword arguments passed to each transform.

@@ -177,9 +170,9 @@ def __call__(
             The transformed data or tuple of transformed data and log determinant of the Jacobian.
         """
         if inverse:
-            return self.inverse(data, stage=stage, **kwargs)
+            return self.inverse(data, **kwargs)

-        return self.forward(data, stage=stage, **kwargs)
+        return self.forward(data, **kwargs)

     def __repr__(self):
         result = ""
@@ -701,43 +694,6 @@ def map_dtype(self, keys: str | Sequence[str], to_dtype: str):
         self.transforms.append(transform)
         return self

-    def nnpe(
-        self,
-        keys: str | Sequence[str],
-        *,
-        spike_scale: float | None = None,
-        slab_scale: float | None = None,
-        per_dimension: bool = True,
-        seed: int | None = None,
-    ):
-        """Append an :py:class:`~transforms.NNPE` transform to the adapter.
-
-        Parameters
-        ----------
-        keys : str or Sequence of str
-            The names of the variables to transform.
-        spike_scale : float or np.ndarray or None, default=None
-            The scale of the spike (Normal) distribution. Automatically determined if None.
-        slab_scale : float or np.ndarray or None, default=None
-            The scale of the slab (Cauchy) distribution. Automatically determined if None.
-        per_dimension : bool, default=True
-            If true, noise is applied per dimension of the last axis of the input data.
-            If false, noise is applied globally.
-        seed : int or None
-            The seed for the random number generator. If None, a random seed is used.
-        """
-        if isinstance(keys, str):
-            keys = [keys]
-
-        transform = MapTransform(
-            {
-                key: NNPE(spike_scale=spike_scale, slab_scale=slab_scale, per_dimension=per_dimension, seed=seed)
-                for key in keys
-            }
-        )
-        self.transforms.append(transform)
-        return self
-
     def one_hot(self, keys: str | Sequence[str], num_classes: int):
         """Append a :py:class:`~transforms.OneHot` transform to the adapter.

@@ -857,6 +813,8 @@ def standardize(
         self,
         include: str | Sequence[str] = None,
         *,
+        mean: int | float | np.ndarray,
+        std: int | float | np.ndarray,
         predicate: Predicate = None,
         exclude: str | Sequence[str] = None,
         **kwargs,
@@ -865,10 +823,14 @@ def standardize(

         Parameters
         ----------
-        predicate : Predicate, optional
-            Function that indicates which variables should be transformed.
         include : str or Sequence of str, optional
             Names of variables to include in the transform.
+        mean : int or float
+            Specifies the mean (location) of the transform.
+        std : int or float
+            Specifies the standard deviation (scale) of the transform.
+        predicate : Predicate, optional
+            Function that indicates which variables should be transformed.
         exclude : str or Sequence of str, optional
             Names of variables to exclude from the transform.
         **kwargs :
@@ -879,6 +841,8 @@ def standardize(
             predicate=predicate,
             include=include,
             exclude=exclude,
+            mean=mean,
+            std=std,
             **kwargs,
         )
         self.transforms.append(transform)
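
Taken together, the changes in this file remove the stage argument from forward/inverse/__call__, delete the nnpe() convenience method, and make mean and std required keyword-only arguments of standardize(). A minimal sketch of the updated call pattern, consistent with the new signatures (the variable name "beta" and the toy data are illustrative, not from this commit):

import numpy as np
import bayesflow as bf

# standardize() now requires fixed mean and std (keyword-only).
adapter = bf.Adapter().standardize(include="beta", mean=0.6, std=3.0)

data = {"beta": np.random.normal(0.6, 3.0, size=(32, 2))}  # toy data

# No `stage` keyword anymore; transforms behave identically in every stage.
standardized = adapter(data)
recovered = adapter(standardized, inverse=True)

# log_det_jac=True still returns a (data, log_det_jac) tuple.
standardized, ldj = adapter(data, log_det_jac=True)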

bayesflow/adapters/transforms/__init__.py

Lines changed: 0 additions & 1 deletion
@@ -12,7 +12,6 @@
 from .keep import Keep
 from .log import Log
 from .map_transform import MapTransform
-from .nnpe import NNPE
 from .numpy_transform import NumpyTransform
 from .one_hot import OneHot
 from .rename import Rename

bayesflow/adapters/transforms/standardize.py

Lines changed: 13 additions & 86 deletions
@@ -1,6 +1,3 @@
-from collections.abc import Sequence
-import warnings
-
 import numpy as np

 from bayesflow.utils.serialization import serializable, serialize
@@ -11,120 +8,50 @@
 @serializable("bayesflow.adapters")
 class Standardize(ElementwiseTransform):
     """
-    Transform that when applied standardizes data using typical z-score standardization
-    i.e. for some unstandardized data x the standardized version z would be
+    Transform that when applied standardizes data using typical z-score standardization with
+    fixed means and std, i.e. for some unstandardized data x the standardized version z would be

     >>> z = (x - mean(x)) / std(x)

+    Important: Ensure dynamic standardization (employed by BayesFlow approximators) has been
+    turned off when using this transform.
+
     Parameters
     ----------
-    mean : int or float, optional
-        Specify a mean if known but will be estimated from data when not provided
-    std : int or float, optional
-        Specify a standard devation if known but will be estimated from data when not provided
-    axis : int, optional
-        A specific axis along which standardization should take place. By default
-        standardization happens individually for each dimension
-    momentum : float in (0,1)
-        The momentum during training
+    mean : int or float
+        Specifies the mean (location) of the transform.
+    std : int or float
+        Specifies the standard deviation (scale) of the transform.

     Examples
     --------
-    1) Standardize all variables using their individually estimated mean and stds.
-
-    >>> adapter = (
-        bf.adapters.Adapter()
-        .standardize()
-    )
-
-
-    2) Standardize all with same known mean and std.
-
-    >>> adapter = (
-        bf.adapters.Adapter()
-        .standardize(mean = 5, sd = 10)
-    )
-
-
-    3) Mix of fixed and estimated means/stds. Suppose we have priors for "beta" and "sigma" where we
-    know the means and stds. However for all other variables, the means and stds are unknown.
-    Then standardize should be used in several stages specifying which variables to include or exclude.
-
-    >>> adapter = (
-        bf.adapters.Adapter()
-        # mean fixed, std estimated
-        .standardize(include = "beta", mean = 1)
-        # both mean and SD fixed
-        .standardize(include = "sigma", mean = 0.6, sd = 3)
-        # both means and stds estimated for all other variables
-        .standardize(exclude = ["beta", "sigma"])
-    )
+    >>> adapter = bf.Adapter().standardize(include="beta", mean=5, std=10)
     """

     def __init__(
         self,
-        mean: int | float | np.ndarray = None,
-        std: int | float | np.ndarray = None,
-        axis: int | Sequence[int] = None,
-        momentum: float | None = 0.99,
+        mean: int | float | np.ndarray,
+        std: int | float | np.ndarray,
     ):
         super().__init__()

-        if mean is None or std is None:
-            warnings.warn(
-                "Dynamic standardization is deprecated and will be removed in later versions."
-                "Instead, use the standardize argument of the approximator / workflow instance or provide "
-                "fixed mean and std arguments. You may incur some redundant computations if you keep this transform.",
-                FutureWarning,
-            )
-
         self.mean = mean
         self.std = std

-        if isinstance(axis, Sequence):
-            # numpy hates lists
-            axis = tuple(axis)
-        self.axis = axis
-        self.momentum = momentum
-
     def get_config(self) -> dict:
         config = {
             "mean": self.mean,
             "std": self.std,
-            "axis": self.axis,
-            "momentum": self.momentum,
         }
         return serialize(config)

-    def forward(self, data: np.ndarray, stage: str = "inference", **kwargs) -> np.ndarray:
-        if self.axis is None:
-            self.axis = tuple(range(data.ndim - 1))
-
-        if self.mean is None:
-            self.mean = np.mean(data, axis=self.axis, keepdims=True)
-        else:
-            if self.momentum is not None and stage == "training":
-                self.mean = self.momentum * self.mean + (1.0 - self.momentum) * np.mean(
-                    data, axis=self.axis, keepdims=True
-                )
-
-        if self.std is None:
-            self.std = np.std(data, axis=self.axis, keepdims=True, ddof=1)
-        else:
-            if self.momentum is not None and stage == "training":
-                self.std = self.momentum * self.std + (1.0 - self.momentum) * np.std(
-                    data, axis=self.axis, keepdims=True, ddof=1
-                )
-
+    def forward(self, data: np.ndarray, **kwargs) -> np.ndarray:
         mean = np.broadcast_to(self.mean, data.shape)
         std = np.broadcast_to(self.std, data.shape)

         return (data - mean) / std

     def inverse(self, data: np.ndarray, **kwargs) -> np.ndarray:
-        if self.mean is None or self.std is None:
-            raise RuntimeError("Cannot call `inverse` before calling `forward` at least once.")
-
         mean = np.broadcast_to(self.mean, data.shape)
         std = np.broadcast_to(self.std, data.shape)
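
With the momentum-based estimation path removed, Standardize carries no fitted state: forward and inverse are exact mutual inverses determined entirely by the constructor arguments. A small sketch using the class directly (assuming it is exported from bayesflow.adapters.transforms like the other transforms shown above; the adapter method standardize() is the usual entry point):

import numpy as np
from bayesflow.adapters.transforms import Standardize

t = Standardize(mean=5.0, std=10.0)

x = np.array([[5.0, 15.0], [-5.0, 25.0]])
z = t.forward(x)                     # computes (x - 5) / 10
assert np.allclose(t.inverse(z), x)  # exact round trip, no fitted state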

bayesflow/approximators/continuous_approximator.py

Lines changed: 2 additions & 11 deletions
@@ -3,7 +3,6 @@
 import numpy as np

 import keras
-import warnings

 from bayesflow.adapters import Adapter
 from bayesflow.networks import InferenceNetwork, SummaryNetwork
@@ -476,7 +475,7 @@ def _prepare_data(
         Handles inputs containing only conditions, only inference_variables, or both.
         Optionally tracks log-determinant Jacobian (ldj) of transformations.
         """
-        adapted = self.adapter(data, strict=False, stage="inference", log_det_jac=log_det_jac, **kwargs)
+        adapted = self.adapter(data, strict=False, log_det_jac=log_det_jac, **kwargs)

         if log_det_jac:
             data, ldj = adapted
@@ -565,7 +564,7 @@ def summarize(self, data: Mapping[str, np.ndarray], **kwargs) -> np.ndarray:
         if self.summary_network is None:
             raise ValueError("A summary network is required to compute summaries.")

-        data_adapted = self.adapter(data, strict=False, stage="inference", **kwargs)
+        data_adapted = self.adapter(data, strict=False, **kwargs)
         if "summary_variables" not in data_adapted or data_adapted["summary_variables"] is None:
             raise ValueError("Summary variables are required to compute summaries.")

@@ -575,14 +574,6 @@ def summarize(self, data: Mapping[str, np.ndarray], **kwargs) -> np.ndarray:

         return summaries

-    def summaries(self, data: Mapping[str, np.ndarray], **kwargs) -> np.ndarray:
-        """
-        .. deprecated:: 2.0.4
-            `summaries` will be removed in version 2.0.5, it was renamed to `summarize` which should be used instead.
-        """
-        warnings.warn("`summaries` was renamed to `summarize` and will be removed in version 2.0.5.", FutureWarning)
-        return self.summarize(data=data, **kwargs)
-
     def log_prob(self, data: Mapping[str, np.ndarray], **kwargs) -> np.ndarray:
         """
         Computes the log-probability of given data under the model. The `data` dictionary is preprocessed using the
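
With the deprecated summaries() alias removed, summarize() is the only remaining entry point for computing learned summary statistics, and the adapter is now invoked without a stage argument. A brief usage sketch (approximator stands in for an already-trained ContinuousApproximator with a summary network; observations is a placeholder data dict):

# `summaries()` is gone as of this commit; call `summarize()` instead.
summaries = approximator.summarize(observations)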
