Skip to content

Commit 98a6bca

Browse files
committed
squeeze: adapt example, add comment for changing batch dims
1 parent 25b73d3 commit 98a6bca

File tree

2 files changed

+5
-3
lines changed

2 files changed

+5
-3
lines changed

bayesflow/adapters/adapter.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -789,7 +789,8 @@ def squeeze(self, keys: str | Sequence[str], *, axis: int | tuple):
789789
keys : str or Sequence of str
790790
The names of the variables to squeeze.
791791
axis : int or tuple
792-
The axis to squeeze.
792+
The axis to squeeze. As the number of batch dimensions might change, we advise using negative
793+
numbers (i.e., indexing from the end instead of the start).
793794
"""
794795
if isinstance(keys, str):
795796
keys = [keys]

bayesflow/adapters/transforms/squeeze.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,15 +13,16 @@ class Squeeze(ElementwiseTransform):
1313
Parameters
1414
----------
1515
axis : int or tuple
16-
The axis to squeeze.
16+
The axis to squeeze. As the number of batch dimensions might change, we advise using negative
17+
numbers (i.e., indexing from the end instead of the start).
1718
1819
Examples
1920
--------
2021
shape (3, 1) array:
2122
2223
>>> a = np.array([[1], [2], [3]])
2324
24-
>>> sq = bf.adapters.transforms.Squeeze(axis=1)
25+
>>> sq = bf.adapters.transforms.Squeeze(axis=-1)
2526
>>> sq.forward(a).shape
2627
(3,)
2728

0 commit comments

Comments
 (0)