Skip to content

Commit 2674a41

Browse files
Fix circular imports
1 parent 25263a7 commit 2674a41

File tree

7 files changed

+97
-80
lines changed

7 files changed

+97
-80
lines changed

src/probnum/__init__.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -24,19 +24,19 @@
2424
LambdaStoppingCriterion,
2525
)
2626

27-
# isort: on
28-
27+
# Supporting packages need to be imported before compat
2928
from . import (
30-
diffeq,
31-
filtsmooth,
32-
linalg,
3329
linops,
34-
problems,
35-
quad,
3630
randprocs,
3731
randvars,
38-
utils,
3932
)
33+
34+
# Compatibility functionality between backend, linops and randvars
35+
from . import compat
36+
37+
# isort: on
38+
39+
from . import diffeq, filtsmooth, linalg, problems, quad, utils
4040
from ._version import version as __version__
4141
from .randvars import asrandvar
4242

src/probnum/backend/_core/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -111,6 +111,7 @@ def as_scalar(x: ScalarArgType, dtype: DTypeArgType = None) -> ArrayType:
111111
"cast",
112112
"promote_types",
113113
"is_floating",
114+
"is_floating_dtype",
114115
"finfo",
115116
# Shape Arithmetic
116117
"reshape",

src/probnum/compat/_core.py

Lines changed: 59 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
11
from typing import Tuple, Union
22

33
import numpy as np
4+
import scipy.sparse
45

5-
from probnum import backend, linops
6+
from probnum import backend, linops, randvars
67

78
__all__ = [
89
"to_numpy",
@@ -34,3 +35,60 @@ def cast(a, dtype=None, casting="unsafe", copy=None):
3435
return a.astype(dtype=dtype, casting=casting, copy=copy)
3536

3637
return backend.cast(a, dtype=dtype, casting=casting, copy=copy)
38+
39+
40+
def atleast_1d(
41+
*objs: Union[
42+
backend.ndarray,
43+
linops.LinearOperator,
44+
randvars.RandomVariable,
45+
]
46+
) -> Union[
47+
Union[
48+
backend.ndarray,
49+
linops.LinearOperator,
50+
randvars.RandomVariable,
51+
],
52+
Tuple[
53+
Union[
54+
backend.ndarray,
55+
linops.LinearOperator,
56+
randvars.RandomVariable,
57+
],
58+
...,
59+
],
60+
]:
61+
"""Reshape arrays, linear operators and random variables to have at least 1
62+
dimension.
63+
64+
Scalar inputs are converted to 1-dimensional arrays, whilst
65+
higher-dimensional inputs are preserved.
66+
67+
Parameters
68+
----------
69+
objs:
70+
One or more input linear operators, random variables or arrays.
71+
72+
Returns
73+
-------
74+
res :
75+
An array / random variable / linop or tuple of arrays / random variables /
76+
linear operators, each with ``a.ndim >= 1``.
77+
"""
78+
res = []
79+
80+
for obj in objs:
81+
if isinstance(obj, np.ndarray):
82+
obj = np.atleast_1d(obj)
83+
elif isinstance(obj, backend.ndarray):
84+
obj = backend.atleast_1d(obj)
85+
elif isinstance(obj, randvars.RandomVariable):
86+
if obj.ndim == 0:
87+
obj = obj.reshape((1,))
88+
89+
res.append(obj)
90+
91+
if len(res) == 1:
92+
return res[0]
93+
94+
return tuple(res)

src/probnum/linalg/_problinsolve.py

Lines changed: 25 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
import scipy.sparse
1414

1515
import probnum # pylint: disable=unused-import
16-
from probnum import linops, randvars, utils
16+
from probnum import linops, randvars
1717
from probnum.linalg.solvers.matrixbased import SymmetricMatrixBasedSolver
1818
from probnum.typing import LinearOperatorArgType
1919

@@ -199,7 +199,7 @@ def problinsolve(
199199
# Select and initialize solver
200200
linear_solver = _init_solver(
201201
A=A,
202-
b=utils.as_colvec(b[:, i]),
202+
b=as_colvec(b[:, i]),
203203
A0=A0,
204204
Ainv0=Ainv0,
205205
x0=x,
@@ -342,9 +342,9 @@ def _preprocess_linear_system(A, b, x0=None):
342342
"""
343343
# Transform linear system to correct dimensions
344344
if not isinstance(b, randvars.RandomVariable):
345-
b = utils.as_colvec(b) # (n,) -> (n, 1)
345+
b = as_colvec(b) # (n,) -> (n, 1)
346346
if x0 is not None:
347-
x0 = utils.as_colvec(x0) # (n,) -> (n, 1)
347+
x0 = as_colvec(x0) # (n,) -> (n, 1)
348348

349349
return A, b, x0
350350

@@ -475,3 +475,24 @@ def _postprocess(info, A):
475475
scipy.linalg.LinAlgWarning,
476476
stacklevel=3,
477477
)
478+
479+
480+
def as_colvec(
481+
vec: Union[np.ndarray, "probnum.randvars.RandomVariable"]
482+
) -> Union[np.ndarray, "probnum.randvars.RandomVariable"]:
483+
"""Transform the given vector or random variable to column format. Given a vector
484+
(or random variable) of dimension (n,) return an array with dimensions (n, 1)
485+
instead. Higher-dimensional arrays are not changed.
486+
487+
Parameters
488+
----------
489+
vec
490+
Vector, array or random variable to be transformed into a column vector.
491+
"""
492+
if isinstance(vec, probnum.randvars.RandomVariable):
493+
if vec.shape != (vec.shape[0], 1):
494+
vec.reshape(newshape=(vec.shape[0], 1))
495+
else:
496+
if vec.ndim == 1:
497+
return vec[:, None]
498+
return vec

src/probnum/randvars/_normal.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import operator
55
from typing import Optional, Union
66

7-
from probnum import backend, compat, config, linops
7+
from probnum import backend, config, linops
88
from probnum.typing import (
99
ArrayLike,
1010
ArrayLikeGetitemArgType,
@@ -82,6 +82,9 @@ def __init__(
8282
if not backend.is_floating_dtype(dtype):
8383
dtype = backend.double
8484

85+
# Circular dependency -> defer import
86+
from probnum import compat # pylint: disable=import-outside-toplevel
87+
8588
mean = compat.cast(mean, dtype=dtype, casting="safe", copy=False)
8689
cov = compat.cast(cov, dtype=dtype, casting="safe", copy=False)
8790

src/probnum/utils/__init__.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,9 @@
11
"""Utility Functions."""
22

33
from .argutils import *
4-
from .arrayutils import *
54

65
# Public classes and functions. Order is reflected in documentation.
76
__all__ = [
8-
"as_colvec",
9-
"atleast_1d",
107
"as_numpy_scalar",
118
"as_shape",
129
]

src/probnum/utils/arrayutils.py

Lines changed: 0 additions & 63 deletions
This file was deleted.

0 commit comments

Comments
 (0)