Skip to content

Commit 5659773

Browse files
committed
fix type hint in squeeze [no ci]
1 parent 664f00a commit 5659773

File tree

2 files changed

+5
-2
lines changed

2 files changed

+5
-2
lines changed

bayesflow/adapters/adapter.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -820,7 +820,7 @@ def split(self, key: str, *, into: Sequence[str], indices_or_sections: int | Seq
820820

821821
return self
822822

823-
def squeeze(self, keys: str | Sequence[str], *, axis: int | tuple):
823+
def squeeze(self, keys: str | Sequence[str], *, axis: int | Sequence[int]):
824824
"""Append a :py:class:`~transforms.Squeeze` transform to the adapter.
825825
826826
Parameters

bayesflow/adapters/transforms/squeeze.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import numpy as np
22

3+
from collections.abc import Sequence
34
from bayesflow.utils.serialization import serializable, serialize
45

56
from .elementwise_transform import ElementwiseTransform
@@ -29,8 +30,10 @@ class Squeeze(ElementwiseTransform):
2930
It is recommended to precede this transform with a :class:`~bayesflow.adapters.transforms.ToArray` transform.
3031
"""
3132

32-
def __init__(self, *, axis: int | tuple):
33+
def __init__(self, *, axis: int | Sequence[int]):
3334
super().__init__()
35+
if isinstance(axis, Sequence):
36+
axis = tuple(axis)
3437
self.axis = axis
3538

3639
def get_config(self) -> dict:

0 commit comments

Comments
 (0)