|
4 | 4 | import operator |
5 | 5 | from typing import Any, Callable, Dict, Tuple, Union |
6 | 6 |
|
7 | | -import numpy as np |
8 | | - |
9 | 7 | import probnum.linops as _linear_operators |
10 | | -from probnum import utils as _utils |
| 8 | +from probnum import backend, utils as _utils |
11 | 9 |
|
12 | 10 | from ._constant import Constant as _Constant |
13 | 11 | from ._normal import Normal as _Normal |
@@ -81,7 +79,7 @@ def _apply( |
81 | 79 | rv1 = _asrandvar(rv1) |
82 | 80 | rv2 = _asrandvar(rv2) |
83 | 81 |
|
84 | | - # Search specific operatir |
| 82 | + # Search specific operator |
85 | 83 | key = (type(rv1), type(rv2)) |
86 | 84 |
|
87 | 85 | if key in op_registry: |
@@ -125,9 +123,10 @@ def _rv_binary_op(rv1: _RandomVariable, rv2: _RandomVariable) -> _RandomVariable |
125 | 123 |
|
126 | 124 |
|
127 | 125 | def _make_rv_binary_op_result_shape_dtype_sample_fn(op_fn, rv1, rv2): |
128 | | - rng = np.random.default_rng(1) |
129 | | - sample_fn = lambda size: op_fn( |
130 | | - rv1.sample(size=size, rng=rng), rv2.sample(size=size, rng=rng) |
| 126 | + seed = backend.random.seed(1) |
| 127 | + sample_fn = lambda sample_shape: op_fn( |
| 128 | + rv1.sample(seed=seed, sample_shape=sample_shape), |
| 129 | + rv2.sample(seed=seed, sample_shape=sample_shape), |
131 | 130 | ) |
132 | 131 |
|
133 | 132 | # Infer shape and dtype |
@@ -253,7 +252,7 @@ def _mul_normal_constant( |
253 | 252 | if constant_rv.size == 1: |
254 | 253 | if constant_rv.support == 0: |
255 | 254 | return _Constant( |
256 | | - support=np.zeros_like(norm_rv.mean), |
| 255 | + support=backend.zeros_like(norm_rv.mean), |
257 | 256 | ) |
258 | 257 | else: |
259 | 258 | if norm_rv.cov_cholesky_is_precomputed: |
|
0 commit comments