Skip to content

Commit 25b73d3

Browse files
committed
Add squeeze transform
Very basic transform, just the inverse of expand_dims
1 parent 92426d6 commit 25b73d3

File tree

4 files changed

+62
-0
lines changed

4 files changed

+62
-0
lines changed

bayesflow/adapters/adapter.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
OneHot,
2323
Rename,
2424
SerializableCustomTransform,
25+
Squeeze,
2526
Sqrt,
2627
Standardize,
2728
ToArray,
@@ -780,6 +781,23 @@ def split(self, key: str, *, into: Sequence[str], indices_or_sections: int | Seq
780781

781782
return self
782783

784+
def squeeze(self, keys: str | Sequence[str], *, axis: int | tuple):
785+
"""Append a :py:class:`~transforms.Squeeze` transform to the adapter.
786+
787+
Parameters
788+
----------
789+
keys : str or Sequence of str
790+
The names of the variables to squeeze.
791+
axis : int or tuple
792+
The axis to squeeze.
793+
"""
794+
if isinstance(keys, str):
795+
keys = [keys]
796+
797+
transform = MapTransform({key: Squeeze(axis=axis) for key in keys})
798+
self.transforms.append(transform)
799+
return self
800+
783801
def sqrt(self, keys: str | Sequence[str]):
784802
"""Append an :py:class:`~transforms.Sqrt` transform to the adapter.
785803

bayesflow/adapters/transforms/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
from .serializable_custom_transform import SerializableCustomTransform
2020
from .shift import Shift
2121
from .split import Split
22+
from .squeeze import Squeeze
2223
from .sqrt import Sqrt
2324
from .standardize import Standardize
2425
from .to_array import ToArray
Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
import numpy as np
2+
3+
from bayesflow.utils.serialization import serializable, serialize
4+
5+
from .elementwise_transform import ElementwiseTransform
6+
7+
8+
@serializable("bayesflow.adapters")
9+
class Squeeze(ElementwiseTransform):
10+
"""
11+
Squeeze dimensions of an array.
12+
13+
Parameters
14+
----------
15+
axis : int or tuple
16+
The axis to squeeze.
17+
18+
Examples
19+
--------
20+
shape (3, 1) array:
21+
22+
>>> a = np.array([[1], [2], [3]])
23+
24+
>>> sq = bf.adapters.transforms.Squeeze(axis=1)
25+
>>> sq.forward(a).shape
26+
(3,)
27+
28+
It is recommended to precede this transform with a :class:`~bayesflow.adapters.transforms.ToArray` transform.
29+
"""
30+
31+
def __init__(self, *, axis: int | tuple):
32+
super().__init__()
33+
self.axis = axis
34+
35+
def get_config(self) -> dict:
36+
return serialize({"axis": self.axis})
37+
38+
def forward(self, data: np.ndarray, **kwargs) -> np.ndarray:
39+
return np.squeeze(data, axis=self.axis)
40+
41+
def inverse(self, data: np.ndarray, **kwargs) -> np.ndarray:
42+
return np.expand_dims(data, axis=self.axis)

tests/test_adapters/conftest.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ def serializable_fn(x):
2121
.concatenate(["x1", "x2"], into="x")
2222
.concatenate(["y1", "y2"], into="y")
2323
.expand_dims(["z1"], axis=2)
24+
.squeeze("z1", axis=2)
2425
.log("p1")
2526
.constrain("p2", lower=0)
2627
.apply(include="p2", forward="exp", inverse="log")

0 commit comments

Comments
 (0)