
Commit b4d0a72

eodole and LarsKue authored
Subset arrays (#411)
* made initial backend functions for adapter subsetting; still need to make the squeeze function and link it to the front end
* added subsample functionality; still to do: adding it to the testing procedures
* made the take function and ran the linter
* changed name of subsampling function
* changed documentation to be consistent with external notation rather than internal shorthand
* small formatting change to documentation
* changed subsample to have sample size and axis in the constructor
* moved transforms in adapter.py so they're in alphabetical order like the other transforms
* changed random_subsample to a MapTransform rather than a FilterTransform
* updated documentation with new naming convention
* added arguments of take to the constructor
* added feature to specify a percentage of the data to subsample rather than only integer input
* changed subsample in adapter.py to allow a float as input for the sample size
* renamed subsample_array and associated classes/functions to RandomSubsample and random_subsample, respectively
* included TypeError to force users to only subsample one dataset at a time
* ran linter
* re-ran formatter
* clean up random subsample transform and docs
* clean up take transform and docs
* nitpick clean-up
* skip shape check for subsampled adapter transform inverse
* fix serialization of new transforms
* skip randomly subsampled key in serialization consistency check

---------

Co-authored-by: LarsKue <lars@kuehmichel.de>
1 parent 52bdb58 commit b4d0a72
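
In short, the commit gives the Adapter two new array-subsetting methods: random_subsample, which keeps a random subset of one batch entry along an axis, and take, which keeps a fixed set of indices. A minimal usage sketch based on the diffs below; the key name "observables" and the array shape are made up for illustration:

import numpy as np
from bayesflow.adapters import Adapter

adapter = (
    Adapter()
    # keep a random 50% of the entries along the last axis of "observables"
    .random_subsample("observables", sample_size=0.5, axis=-1)
    # then keep only the first three of the remaining entries
    .take("observables", indices=np.arange(0, 3), axis=-1)
)

data = {"observables": np.random.standard_normal(size=(32, 10))}
subset = adapter(data)  # "observables" shrinks along the last axis; other keys pass through unchanged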

File tree

7 files changed (+159, -11 lines):

.gitignore
bayesflow/adapters/adapter.py
bayesflow/adapters/transforms/__init__.py
bayesflow/adapters/transforms/random_subsample.py
bayesflow/adapters/transforms/take.py
tests/test_adapters/conftest.py
tests/test_adapters/test_adapters.py

.gitignore

Lines changed: 3 additions & 0 deletions
@@ -39,3 +39,6 @@ docs/

 # MacOS
 .DS_Store
+
+# Rproj
+.Rproj.user

bayesflow/adapters/adapter.py

Lines changed: 61 additions & 1 deletion
@@ -25,6 +25,8 @@
     Standardize,
     ToArray,
     Transform,
+    RandomSubsample,
+    Take,
 )
 from .transforms.filter_transform import Predicate

@@ -665,6 +667,28 @@ def one_hot(self, keys: str | Sequence[str], num_classes: int):
         self.transforms.append(transform)
         return self

+    def random_subsample(self, key: str, *, sample_size: int | float, axis: int = -1):
+        """
+        Append a :py:class:`~transforms.RandomSubsample` transform to the adapter.
+
+        Parameters
+        ----------
+        key : str
+            The name of the variable to subsample.
+        sample_size : int or float
+            The number of samples to draw, or a fraction between 0 and 1 of the total number of samples to draw.
+        axis : int, optional
+            Which axis to draw samples over. The last axis is used by default.
+        """
+
+        if not isinstance(key, str):
+            raise TypeError("Can only subsample one batch entry at a time.")
+
+        transform = MapTransform({key: RandomSubsample(sample_size=sample_size, axis=axis)})
+
+        self.transforms.append(transform)
+        return self
+
     def rename(self, from_key: str, to_key: str):
         """Append a :py:class:`~transforms.Rename` transform to the adapter.

@@ -741,7 +765,7 @@ def standardize(
             Names of variables to include in the transform.
         exclude : str or Sequence of str, optional
             Names of variables to exclude from the transform.
-        **kwargs : dict
+        **kwargs :
             Additional keyword arguments passed to the transform.
         """
         transform = FilterTransform(
@@ -754,6 +778,42 @@
         self.transforms.append(transform)
         return self

+    def take(
+        self,
+        include: str | Sequence[str] = None,
+        *,
+        indices: Sequence[int],
+        axis: int = -1,
+        predicate: Predicate = None,
+        exclude: str | Sequence[str] = None,
+    ):
+        """
+        Append a :py:class:`~transforms.Take` transform to the adapter.
+
+        Parameters
+        ----------
+        include : str or Sequence of str, optional
+            Names of variables to include in the transform.
+        indices : Sequence of int
+            Which indices to take from the data.
+        axis : int, optional
+            Which axis to take from. The last axis is used by default.
+        predicate : Predicate, optional
+            Function that indicates which variables should be transformed.
+        exclude : str or Sequence of str, optional
+            Names of variables to exclude from the transform.
+        """
+        transform = FilterTransform(
+            transform_constructor=Take,
+            predicate=predicate,
+            include=include,
+            exclude=exclude,
+            indices=indices,
+            axis=axis,
+        )
+        self.transforms.append(transform)
+        return self
+
     def to_array(
         self,
         include: str | Sequence[str] = None,
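
As the diff shows, the two methods are wired up differently: take goes through FilterTransform, so it can target several keys at once via include / exclude / predicate, while random_subsample wraps exactly one key in a MapTransform and rejects anything else with a TypeError. A small sketch of that difference (key names are hypothetical):

from bayesflow.adapters import Adapter

adapter = Adapter()

# take may be applied to several keys in one call
adapter.take(["x", "y"], indices=[0, 1, 2], axis=-1)

# random_subsample is restricted to a single key per call ...
adapter.random_subsample("x", sample_size=10, axis=0)

# ... so passing a list of keys is rejected
try:
    adapter.random_subsample(["x", "y"], sample_size=10, axis=0)
except TypeError as err:
    print(err)  # Can only subsample one batch entry at a time.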

bayesflow/adapters/transforms/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -23,6 +23,8 @@
 from .to_array import ToArray
 from .to_dict import ToDict
 from .transform import Transform
+from .random_subsample import RandomSubsample
+from .take import Take

 from ...utils._docs import _add_imports_to_all

bayesflow/adapters/transforms/random_subsample.py

Lines changed: 48 additions & 0 deletions
@@ -0,0 +1,48 @@
+import numpy as np
+from bayesflow.utils.serialization import serializable, serialize
+from .elementwise_transform import ElementwiseTransform
+
+
+@serializable(package="bayesflow.adapters")
+class RandomSubsample(ElementwiseTransform):
+    """
+    A transform that takes a random subsample of the data within an axis.
+
+    Example: adapter.random_subsample("x", sample_size=3, axis=-1)
+
+    """
+
+    def __init__(
+        self,
+        sample_size: int | float,
+        axis: int = -1,
+    ):
+        super().__init__()
+        if isinstance(sample_size, float):
+            if sample_size <= 0 or sample_size >= 1:
+                raise ValueError("Sample size as a percentage must be a float between 0 and 1, exclusive.")
+        self.sample_size = sample_size
+        self.axis = axis
+
+    def forward(self, data: np.ndarray, **kwargs) -> np.ndarray:
+        axis = self.axis
+        max_sample_size = data.shape[axis]
+
+        if isinstance(self.sample_size, int):
+            sample_size = self.sample_size
+        else:
+            sample_size = int(np.round(self.sample_size * max_sample_size))
+
+        # random sample without replacement
+        sample_indices = np.random.permutation(max_sample_size)[:sample_size]
+
+        return np.take(data, sample_indices, axis)
+
+    def inverse(self, data: np.ndarray, **kwargs) -> np.ndarray:
+        # non-invertible transform
+        return data
+
+    def get_config(self) -> dict:
+        config = {"sample_size": self.sample_size, "axis": self.axis}
+
+        return serialize(config)
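
The forward pass above boils down to a standard numpy idiom: permute the index range of the chosen axis, keep the first sample_size entries, and gather them with np.take, which yields a draw without replacement. A standalone sketch (array shape and sizes are made up), covering both the integer and the fractional form of sample_size:

import numpy as np

x = np.random.standard_normal(size=(4, 10))  # 10 candidate entries along the last axis

# integer sample_size: keep exactly that many entries, drawn without replacement
idx = np.random.permutation(x.shape[-1])[:3]
print(np.take(x, idx, axis=-1).shape)  # (4, 3)

# float sample_size: treated as a fraction of the axis length, rounded to a count
frac = 0.5
count = int(np.round(frac * x.shape[-1]))
idx = np.random.permutation(x.shape[-1])[:count]
print(np.take(x, idx, axis=-1).shape)  # (4, 5)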

bayesflow/adapters/transforms/take.py

Lines changed: 31 additions & 0 deletions
@@ -0,0 +1,31 @@
+from collections.abc import Sequence
+import numpy as np
+
+from bayesflow.utils.serialization import serializable, serialize
+
+from .elementwise_transform import ElementwiseTransform
+
+
+@serializable(package="bayesflow.adapters")
+class Take(ElementwiseTransform):
+    """
+    A transform to reduce the dimensionality of arrays output by the summary network.
+    Example: adapter.take("x", indices=np.arange(0, 3), axis=-1)
+    """
+
+    def __init__(self, indices: Sequence[int], axis: int = -1):
+        super().__init__()
+        self.indices = indices
+        self.axis = axis
+
+    def forward(self, data: np.ndarray, **kwargs) -> np.ndarray:
+        return np.take(data, self.indices, self.axis)
+
+    def inverse(self, data: np.ndarray, **kwargs) -> np.ndarray:
+        # not a true inverse; Take is not invertible
+        return data
+
+    def get_config(self) -> dict:
+        config = {"indices": self.indices, "axis": self.axis}
+
+        return serialize(config)

tests/test_adapters/conftest.py

Lines changed: 7 additions & 10 deletions
@@ -11,7 +11,7 @@ def adapter():
     def serializable_fn(x):
         return x

-    d = (
+    return (
         Adapter()
         .to_array()
         .as_set(["s1", "s2"])
@@ -32,12 +32,12 @@ def serializable_fn(x):
         .standardize(exclude=["t1", "t2", "o1"])
         .drop("d1")
         .one_hot("o1", 10)
-        .keep(["x", "y", "z1", "p1", "p2", "s1", "s2", "t1", "t2", "o1", "split_1", "split_2"])
+        .keep(["x", "y", "z1", "p1", "p2", "s1", "s2", "s3", "t1", "t2", "o1", "split_1", "split_2"])
         .rename("o1", "o2")
+        .random_subsample("s3", sample_size=33, axis=0)
+        .take("s3", indices=np.arange(0, 32), axis=0)
     )

-    return d
-

 @pytest.fixture()
 def random_data():
@@ -58,6 +58,7 @@ def random_data():
         "d1": np.random.standard_normal(size=(32, 2)),
         "d2": np.random.standard_normal(size=(32, 2)),
         "o1": np.random.randint(0, 9, size=(32, 2)),
+        "s3": np.random.standard_normal(size=(35, 2)),
         "u1": np.random.uniform(low=-1, high=2, size=(32, 1)),
         "key_to_split": np.random.standard_normal(size=(32, 10)),
     }
@@ -67,7 +68,7 @@
 def adapter_log_det_jac():
     from bayesflow.adapters import Adapter

-    adapter = (
+    return (
         Adapter()
         .scale("x1", by=2)
         .log("p1", p1=True)
@@ -79,14 +80,12 @@
         .rename("u1", "u")
     )

-    return adapter
-

 @pytest.fixture()
 def adapter_log_det_jac_inverse():
     from bayesflow.adapters import Adapter

-    adapter = (
+    return (
         Adapter()
         .standardize("x1", mean=1, std=2)
         .log("p1")
@@ -96,5 +95,3 @@
         .constrain("u1", lower=-1, upper=2)
         .scale(["p1", "p2", "p3"], by=3.5)
     )
-
-    return adapter
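
Read together, the fixture changes give "s3" a shrinking shape: the raw draw has 35 rows on axis 0, random_subsample keeps 33 of them, and take then keeps indices 0-31, so 32 rows reach the tests. A standalone sketch of that chain outside the fixture:

import numpy as np

s3 = np.random.standard_normal(size=(35, 2))

subsampled = np.take(s3, np.random.permutation(35)[:33], axis=0)  # 33 of the 35 rows, without replacement
taken = np.take(subsampled, np.arange(0, 32), axis=0)             # then the first 32 of those rows

print(subsampled.shape, taken.shape)  # (33, 2) (32, 2)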

tests/test_adapters/test_adapters.py

Lines changed: 7 additions & 0 deletions
@@ -16,6 +16,9 @@ def test_cycle_consistency(adapter, random_data):
         if key in ["d1", "d2", "p3", "n1", "u1"]:
             # dropped
             continue
+        if key == "s3":
+            # we subsampled this key, so it is expected for its shape to change
+            continue
         assert key in deprocessed
         assert np.allclose(value, deprocessed[key])

@@ -31,6 +34,10 @@ def test_serialize_deserialize(adapter, random_data):
     random_data["foo"] = random_data["x1"]
     deserialized_processed = deserialized(random_data)
     for key, value in processed.items():
+        if key == "s3":
+            # skip this key because it is *randomly* subsampled
+            continue
+
         assert np.allclose(value, deserialized_processed[key])
