Commit fef8b90

[no ci] Merge remote-tracking branch 'upstream/dev' into docs-user-guide
2 parents: 6ecd258 + e2c8304

27 files changed: +1522 −158 lines

bayesflow/adapters/transforms/standardize.py

Lines changed: 1 addition & 1 deletion
@@ -123,7 +123,7 @@ def inverse(self, data: np.ndarray, **kwargs) -> np.ndarray:

     def log_det_jac(self, data, inverse: bool = False, **kwargs) -> np.ndarray:
         std = np.broadcast_to(self.std, data.shape)
-        ldj = np.log(np.abs(std))
+        ldj = -np.log(np.abs(std))
         if inverse:
             ldj = -ldj
         return np.sum(ldj, axis=tuple(range(1, ldj.ndim)))
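
Note: the sign fix aligns `log_det_jac` with the change-of-variables formula. The forward standardization z = (x - mean) / std has element-wise derivative 1 / std, so its log-det-Jacobian is -log|std|, not +log|std|. A minimal NumPy sketch (toy shapes, not from the repository) checking the corrected value:

import numpy as np

# Standardization z = (x - mean) / std has element-wise derivative 1 / std,
# so the forward log-det-Jacobian per sample is sum(-log|std|).
std = np.array([0.5, 2.0, 4.0])
data = np.zeros((2, 3))  # batch of 2 samples, 3 features

ldj = -np.log(np.abs(np.broadcast_to(std, data.shape)))
forward = np.sum(ldj, axis=tuple(range(1, ldj.ndim)))

assert np.allclose(forward, -np.log(np.abs(std)).sum())  # same for every sample
# The inverse transform flips the sign, matching the `inverse=True` branch.
assert np.allclose(-forward, np.log(np.abs(std)).sum())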

bayesflow/approximators/continuous_approximator.py

Lines changed: 1 addition & 1 deletion
@@ -458,7 +458,7 @@ def log_prob(self, data: Mapping[str, np.ndarray], **kwargs) -> np.ndarray | dic
         # change of variables formula
         log_det_jac = log_det_jac.get("inference_variables")
         if log_det_jac is not None:
-            log_prob = log_prob + log_det_jac
+            log_prob = keras.tree.map_structure(lambda x: x + log_det_jac, log_prob)

         return log_prob
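
Note: wrapping the addition in `keras.tree.map_structure` lets the Jacobian correction apply whether `log_prob` is a single array or a nested structure such as a dict of per-quantity arrays, consistent with the method's `np.ndarray | dict` return annotation. A toy illustration of that behavior (values hypothetical):

import keras
import numpy as np

log_det_jac = np.array([0.1, -0.2])

# Plain array: the function is applied to it directly.
flat = keras.tree.map_structure(lambda x: x + log_det_jac, np.zeros(2))

# Nested structure: the same correction is added to every leaf.
nested = keras.tree.map_structure(
    lambda x: x + log_det_jac,
    {"theta": np.zeros(2), "sigma": np.ones(2)},
)
print(flat)    # [ 0.1 -0.2]
print(nested)  # {'theta': array([ 0.1, -0.2]), 'sigma': array([1.1, 0.8])}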

bayesflow/datasets/__init__.py

Lines changed: 0 additions & 1 deletion
@@ -7,7 +7,6 @@
 from .offline_dataset import OfflineDataset
 from .online_dataset import OnlineDataset
 from .disk_dataset import DiskDataset
-from .rounds_dataset import RoundsDataset

 from ..utils._docs import _add_imports_to_all

bayesflow/datasets/disk_dataset.py

Lines changed: 52 additions & 6 deletions
@@ -1,8 +1,12 @@
-import keras
-import numpy as np
+from collections.abc import Mapping, Callable
+
 import os
 import pathlib as pl

+import numpy as np
+
+import keras
+
 from bayesflow.adapters import Adapter
 from bayesflow.utils import tree_stack, pickle_load
@@ -29,11 +33,43 @@ def __init__(
         *,
         pattern: str = "*.pkl",
         batch_size: int,
-        load_fn: callable = None,
+        load_fn: Callable = None,
         adapter: Adapter | None,
         stage: str = "training",
+        augmentations: Mapping[str, Callable] | Callable = None,
         **kwargs,
     ):
+        """
+        Initialize a DiskDataset instance for offline training using a set of simulations
+        that do not fit in memory.
+
+        Parameters
+        ----------
+        root : os.PathLike
+            Root directory containing the sample files.
+        pattern : str, default="*.pkl"
+            Glob pattern to match sample files.
+        batch_size : int
+            Number of samples per batch.
+        load_fn : Callable, optional
+            Function to load a single file into a sample. Defaults to `pickle_load`.
+        adapter : Adapter or None
+            Optional adapter to transform the loaded batch.
+        stage : str, default="training"
+            Current stage (e.g., "training", "validation", etc.) used by the adapter.
+        augmentations : dict of str to Callable or Callable, optional
+            Dictionary of augmentation functions to apply to each corresponding key in the
+            batch, or a single function to apply to the entire batch (possibly adding new
+            keys).
+
+            If you provide a dictionary of functions, each function should accept one
+            element of your output batch and return the corresponding transformed element.
+            Otherwise, your function should accept the entire dictionary output and return
+            a dictionary.
+
+            Note: augmentations are applied before the adapter is called and are generally
+            transforms that you only want to apply during training.
+        **kwargs
+            Additional keyword arguments passed to the base `PyDataset`.
+        """
         super().__init__(**kwargs)
         self.batch_size = batch_size
         self.root = pl.Path(root)
@@ -42,6 +78,8 @@ def __init__(
         self.files = list(map(str, self.root.glob(pattern)))
         self.stage = stage

+        self.augmentations = augmentations
+
         self.shuffle()

     def __getitem__(self, item) -> dict[str, np.ndarray]:
@@ -50,12 +88,20 @@ def __getitem__(self, item) -> dict[str, np.ndarray]:

         files = self.files[item * self.batch_size : (item + 1) * self.batch_size]

-        batch = []
-        for file in files:
-            batch.append(self.load_fn(file))
+        batch = [self.load_fn(file) for file in files]

         batch = tree_stack(batch)

+        if self.augmentations is None:
+            pass
+        elif isinstance(self.augmentations, Mapping):
+            for key, fn in self.augmentations.items():
+                batch[key] = fn(batch[key])
+        elif isinstance(self.augmentations, Callable):
+            batch = self.augmentations(batch)
+        else:
+            raise RuntimeError(f"Could not apply augmentations of type {type(self.augmentations)}.")
+
         if self.adapter is not None:
             batch = self.adapter(batch, stage=self.stage)
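
Note: a minimal usage sketch of the new `augmentations` hook in its per-key form (the directory name and the "observables" key are hypothetical):

import numpy as np

from bayesflow.datasets import DiskDataset

# Per-key augmentation: add observation noise to the "observables" entry only.
# Runs before the adapter, so it only affects batches built by this dataset.
def add_noise(x: np.ndarray) -> np.ndarray:
    return x + np.random.normal(scale=0.05, size=x.shape)

dataset = DiskDataset(
    root="simulations/",  # hypothetical directory of pickled samples
    batch_size=32,
    adapter=None,
    augmentations={"observables": add_noise},
)
batch = dataset[0]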

bayesflow/datasets/offline_dataset.py

Lines changed: 60 additions & 2 deletions
@@ -1,4 +1,4 @@
-from collections.abc import Mapping
+from collections.abc import Mapping, Callable

 import numpy as np

@@ -23,8 +23,37 @@ def __init__(
         num_samples: int = None,
         *,
         stage: str = "training",
+        augmentations: Mapping[str, Callable] | Callable = None,
         **kwargs,
     ):
+        """
+        Initialize an OfflineDataset instance for offline training with optional data
+        augmentations.
+
+        Parameters
+        ----------
+        data : Mapping[str, np.ndarray]
+            Pre-simulated data stored in a dictionary, where each key maps to a NumPy array.
+        batch_size : int
+            Number of samples per batch.
+        adapter : Adapter or None
+            Optional adapter to transform the batch.
+        num_samples : int, optional
+            Number of samples in the dataset. If None, it will be inferred from the data.
+        stage : str, default="training"
+            Current stage (e.g., "training", "validation", etc.) used by the adapter.
+        augmentations : dict of str to Callable or Callable, optional
+            Dictionary of augmentation functions to apply to each corresponding key in the
+            batch, or a single function to apply to the entire batch (possibly adding new
+            keys).
+
+            If you provide a dictionary of functions, each function should accept one
+            element of your output batch and return the corresponding transformed element.
+            Otherwise, your function should accept the entire dictionary output and return
+            a dictionary.
+
+            Note: augmentations are applied before the adapter is called and are generally
+            transforms that you only want to apply during training.
+        **kwargs
+            Additional keyword arguments passed to the base `PyDataset`.
+        """
         super().__init__(**kwargs)
         self.batch_size = batch_size
         self.data = data
@@ -39,10 +68,29 @@ def __init__(

         self.indices = np.arange(self.num_samples, dtype="int64")

+        self.augmentations = augmentations
+
         self.shuffle()

     def __getitem__(self, item: int) -> dict[str, np.ndarray]:
-        """Get a batch of pre-simulated data"""
+        """
+        Get a batch of pre-simulated data.
+
+        Parameters
+        ----------
+        item : int
+            Index of the batch to retrieve.
+
+        Returns
+        -------
+        dict of str to np.ndarray
+            A batch of pre-simulated (and optionally augmented/adapted) data.
+
+        Raises
+        ------
+        IndexError
+            If the requested batch index is out of range.
+        """
         if not 0 <= item < self.num_batches:
             raise IndexError(f"Index {item} is out of bounds for dataset with {self.num_batches} batches.")

@@ -54,6 +102,16 @@ def __getitem__(self, item: int) -> dict[str, np.ndarray]:
             for key, value in self.data.items()
         }

+        if self.augmentations is None:
+            pass
+        elif isinstance(self.augmentations, Mapping):
+            for key, fn in self.augmentations.items():
+                batch[key] = fn(batch[key])
+        elif isinstance(self.augmentations, Callable):
+            batch = self.augmentations(batch)
+        else:
+            raise RuntimeError(f"Could not apply augmentations of type {type(self.augmentations)}.")
+
         if self.adapter is not None:
             batch = self.adapter(batch, stage=self.stage)
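
Note: the same hook also accepts a single callable over the whole batch, which may add derived keys. A sketch (key names hypothetical):

import numpy as np

from bayesflow.datasets import OfflineDataset

# Whole-batch augmentation: receives the full dict and may add new keys.
def augment(batch: dict[str, np.ndarray]) -> dict[str, np.ndarray]:
    batch["observables_sq"] = batch["observables"] ** 2
    return batch

data = {"observables": np.random.normal(size=(1000, 10))}
dataset = OfflineDataset(data=data, batch_size=64, adapter=None, augmentations=augment)
batch = dataset[0]  # augmentation runs before the (here absent) adapter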

bayesflow/datasets/online_dataset.py

Lines changed: 56 additions & 1 deletion
@@ -1,3 +1,5 @@
+from collections.abc import Mapping, Callable
+
 import keras
 import numpy as np

@@ -7,7 +9,7 @@

 class OnlineDataset(keras.utils.PyDataset):
     """
-    A dataset that is generated on-the-fly.
+    A dataset that generates simulations on-the-fly.
     """

     def __init__(
@@ -18,19 +20,72 @@ def __init__(
         adapter: Adapter | None,
         *,
         stage: str = "training",
+        augmentations: Mapping[str, Callable] | Callable = None,
         **kwargs,
     ):
+        """
+        Initialize an OnlineDataset instance for infinite stream training.
+
+        Parameters
+        ----------
+        simulator : Simulator
+            A simulator object with a `.sample(batch_shape)` method to generate data.
+        batch_size : int
+            Number of samples per batch.
+        num_batches : int
+            Total number of batches in the dataset.
+        adapter : Adapter or None
+            Optional adapter to transform the simulated batch.
+        stage : str, default="training"
+            Current stage (e.g., "training", "validation", etc.) used by the adapter.
+        augmentations : dict of str to Callable or Callable, optional
+            Dictionary of augmentation functions to apply to each corresponding key in the
+            batch, or a single function to apply to the entire batch (possibly adding new
+            keys).
+
+            If you provide a dictionary of functions, each function should accept one
+            element of your output batch and return the corresponding transformed element.
+            Otherwise, your function should accept the entire dictionary output and return
+            a dictionary.
+
+            Note: augmentations are applied before the adapter is called and are generally
+            transforms that you only want to apply during training.
+        **kwargs
+            Additional keyword arguments passed to the base `PyDataset`.
+        """
         super().__init__(**kwargs)

         self.batch_size = batch_size
         self._num_batches = num_batches
         self.adapter = adapter
         self.simulator = simulator
         self.stage = stage
+        self.augmentations = augmentations

     def __getitem__(self, item: int) -> dict[str, np.ndarray]:
+        """
+        Generate one batch of data.
+
+        Parameters
+        ----------
+        item : int
+            Index of the batch. Required by the signature, but not used.
+
+        Returns
+        -------
+        dict of str to np.ndarray
+            A batch of simulated (and optionally augmented/adapted) data.
+        """
         batch = self.simulator.sample((self.batch_size,))

+        if self.augmentations is None:
+            pass
+        elif isinstance(self.augmentations, Mapping):
+            for key, fn in self.augmentations.items():
+                batch[key] = fn(batch[key])
+        elif isinstance(self.augmentations, Callable):
+            batch = self.augmentations(batch)
+        else:
+            raise RuntimeError(f"Could not apply augmentations of type {type(self.augmentations)}.")
+
         if self.adapter is not None:
             batch = self.adapter(batch, stage=self.stage)
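
Note: for the on-the-fly path, a usage sketch with a stand-in simulator (any object exposing `.sample(batch_shape)` that returns a dict works; the toy model below is not from the repository):

import numpy as np

from bayesflow.datasets import OnlineDataset

class ToySimulator:
    """Stand-in simulator: returns a dict of arrays for a given batch shape."""

    def sample(self, batch_shape):
        theta = np.random.normal(size=batch_shape + (2,))
        x = theta + np.random.normal(size=batch_shape + (2,))
        return {"theta": theta, "x": x}

dataset = OnlineDataset(
    simulator=ToySimulator(),
    batch_size=16,
    num_batches=100,
    adapter=None,
    augmentations={"x": lambda x: x + np.random.normal(scale=0.1, size=x.shape)},
)
batch = dataset[0]  # the index is ignored; each access simulates a fresh batch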

bayesflow/datasets/rounds_dataset.py

Lines changed: 0 additions & 66 deletions
This file was deleted.

bayesflow/experimental/__init__.py

Lines changed: 2 additions & 1 deletion
@@ -4,8 +4,9 @@

 from .cif import CIF
 from .continuous_time_consistency_model import ContinuousTimeConsistencyModel
+from .diffusion_model import DiffusionModel
 from .free_form_flow import FreeFormFlow

 from ..utils._docs import _add_imports_to_all

-_add_imports_to_all(include_modules=[])
+_add_imports_to_all(include_modules=["diffusion_model"])
bayesflow/experimental/diffusion_model/__init__.py

Lines changed: 9 additions & 0 deletions

@@ -0,0 +1,9 @@
+from .diffusion_model import DiffusionModel
+from .noise_schedule import NoiseSchedule
+from .cosine_noise_schedule import CosineNoiseSchedule
+from .edm_noise_schedule import EDMNoiseSchedule
+from .dispatch import find_noise_schedule
+
+from ...utils._docs import _add_imports_to_all
+
+_add_imports_to_all(include_modules=[])
