Skip to content

Commit 08b036d

Browse files
Adapt numpy to backend in arithmetic
1 parent 10789bf commit 08b036d

File tree

1 file changed

+7
-8
lines changed

1 file changed

+7
-8
lines changed

src/probnum/randvars/_arithmetic.py

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,8 @@
44
import operator
55
from typing import Any, Callable, Dict, Tuple, Union
66

7-
import numpy as np
8-
97
import probnum.linops as _linear_operators
10-
from probnum import utils as _utils
8+
from probnum import backend, utils as _utils
119

1210
from ._constant import Constant as _Constant
1311
from ._normal import Normal as _Normal
@@ -81,7 +79,7 @@ def _apply(
8179
rv1 = _asrandvar(rv1)
8280
rv2 = _asrandvar(rv2)
8381

84-
# Search specific operatir
82+
# Search specific operator
8583
key = (type(rv1), type(rv2))
8684

8785
if key in op_registry:
@@ -125,9 +123,10 @@ def _rv_binary_op(rv1: _RandomVariable, rv2: _RandomVariable) -> _RandomVariable
125123

126124

127125
def _make_rv_binary_op_result_shape_dtype_sample_fn(op_fn, rv1, rv2):
128-
rng = np.random.default_rng(1)
129-
sample_fn = lambda size: op_fn(
130-
rv1.sample(size=size, rng=rng), rv2.sample(size=size, rng=rng)
126+
seed = backend.random.seed(1)
127+
sample_fn = lambda sample_shape: op_fn(
128+
rv1.sample(seed=seed, sample_shape=sample_shape),
129+
rv2.sample(seed=seed, sample_shape=sample_shape),
131130
)
132131

133132
# Infer shape and dtype
@@ -253,7 +252,7 @@ def _mul_normal_constant(
253252
if constant_rv.size == 1:
254253
if constant_rv.support == 0:
255254
return _Constant(
256-
support=np.zeros_like(norm_rv.mean),
255+
support=backend.zeros_like(norm_rv.mean),
257256
)
258257
else:
259258
if norm_rv.cov_cholesky_is_precomputed:

0 commit comments

Comments
 (0)