Skip to content

Commit 88c1867

Browse files
Bugfix for to_numpy
1 parent ebecd1f commit 88c1867

File tree

3 files changed

+9
-0
lines changed

3 files changed

+9
-0
lines changed

src/probnum/backend/_core/_jax.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,9 @@ def is_floating_dtype(dtype) -> bool:
6161

6262

6363
def to_numpy(*arrays: jax.numpy.ndarray) -> Tuple[np.ndarray, ...]:
64+
if len(arrays) == 1:
65+
return np.array(arrays[0])
66+
6467
return tuple(np.array(arr) for arr in arrays)
6568

6669

src/probnum/backend/_core/_numpy.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,9 @@ def is_floating_dtype(dtype) -> bool:
5858

5959

6060
def to_numpy(*arrays: np.ndarray) -> Tuple[np.ndarray, ...]:
61+
if len(arrays) == 1:
62+
return arrays[0]
63+
6164
return tuple(arrays)
6265

6366

src/probnum/backend/_core/_torch.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -158,6 +158,9 @@ def cast(a: torch.Tensor, dtype=None, casting="unsafe", copy=None):
158158

159159

160160
def to_numpy(*arrays: torch.Tensor) -> Tuple[np.ndarray, ...]:
161+
if len(arrays) == 1:
162+
return arrays[0].cpu().detach().numpy()
163+
161164
return tuple(arr.cpu().detach().numpy() for arr in arrays)
162165

163166

0 commit comments

Comments
 (0)