Skip to content

Commit 5d3b135

Browse files
Add einsum
1 parent ea228e8 commit 5d3b135

File tree

5 files changed

+74
-47
lines changed

5 files changed

+74
-47
lines changed

src/probnum/backend/__init__.py

Lines changed: 7 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -1,52 +1,5 @@
11
from ._select import Backend, select_backend as _select_backend
22

3-
# pylint: disable=undefined-all-variable
4-
__all__ = [
5-
"ndarray",
6-
# DTypes
7-
"dtype",
8-
"asdtype",
9-
"bool",
10-
"int32",
11-
"int64",
12-
"single",
13-
"double",
14-
"csingle",
15-
"cdouble",
16-
"cast",
17-
"promote_types",
18-
"is_floating",
19-
"finfo",
20-
# Shape Arithmetic
21-
"reshape",
22-
"atleast_1d",
23-
"atleast_2d",
24-
"broadcast_arrays",
25-
"broadcast_shapes",
26-
"ndim",
27-
"swapaxes",
28-
# Constructors
29-
"array",
30-
"asarray",
31-
"diag",
32-
"eye",
33-
"ones",
34-
"ones_like",
35-
"zeros",
36-
"zeros_like",
37-
"linspace",
38-
# Constants
39-
"pi",
40-
"inf",
41-
# Operations
42-
"sin",
43-
"exp",
44-
"log",
45-
"sqrt",
46-
"sum",
47-
"maximum",
48-
]
49-
503
BACKEND = _select_backend()
514

525
# isort: off
@@ -56,10 +9,17 @@
569
from ._core import *
5710

5811
from . import (
12+
_core,
5913
autodiff,
6014
linalg,
6115
random,
6216
special,
6317
)
6418

6519
# isort: on
20+
21+
__all__ = [
22+
"Backend",
23+
"BACKEND",
24+
"Dispatcher",
25+
] + _core.__all__

src/probnum/backend/_core/__init__.py

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,9 @@
6464
# Element-wise Binary Operations
6565
maximum = _core.maximum
6666

67+
# Contractions
68+
einsum = _core.einsum
69+
6770
# Reductions
6871
all = _core.all
6972
sum = _core.sum
@@ -91,3 +94,64 @@ def as_scalar(x: ScalarArgType, dtype: DTypeArgType = None) -> ArrayType:
9194
raise ValueError("The given input is not a scalar.")
9295

9396
return asarray(x, dtype=dtype)[()]
97+
98+
99+
__all__ = [
100+
"ndarray",
101+
# DTypes
102+
"dtype",
103+
"asdtype",
104+
"bool",
105+
"int32",
106+
"int64",
107+
"single",
108+
"double",
109+
"csingle",
110+
"cdouble",
111+
"cast",
112+
"promote_types",
113+
"is_floating",
114+
"finfo",
115+
# Shape Arithmetic
116+
"reshape",
117+
"atleast_1d",
118+
"atleast_2d",
119+
"broadcast_arrays",
120+
"broadcast_shapes",
121+
"ndim",
122+
"swapaxes",
123+
# Constructors
124+
"array",
125+
"asarray",
126+
"as_scalar",
127+
"diag",
128+
"eye",
129+
"full",
130+
"full_like",
131+
"ones",
132+
"ones_like",
133+
"zeros",
134+
"zeros_like",
135+
"linspace",
136+
# Constants
137+
"inf",
138+
"pi",
139+
# Element-wise Unary Operations
140+
"exp",
141+
"isfinite",
142+
"log",
143+
"sin",
144+
"sqrt",
145+
# Element-wise Binary Operations
146+
"maximum",
147+
# Contractions
148+
"einsum",
149+
# Reductions
150+
"all",
151+
"sum",
152+
# Misc
153+
"to_numpy",
154+
# Just-in-Time Compilation
155+
"jit",
156+
"jit_method",
157+
]

src/probnum/backend/_core/_jax.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
double,
1818
dtype,
1919
dtype as asdtype,
20+
einsum,
2021
exp,
2122
eye,
2223
finfo,

src/probnum/backend/_core/_numpy.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
double,
1717
dtype,
1818
dtype as asdtype,
19+
einsum,
1920
exp,
2021
eye,
2122
finfo,

src/probnum/backend/_core/_torch.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
diag,
1616
double,
1717
dtype,
18+
einsum,
1819
exp,
1920
eye,
2021
finfo,

0 commit comments

Comments
 (0)