This repository was archived by the owner on Nov 17, 2025. It is now read-only.
File tree Expand file tree Collapse file tree 4 files changed +10
-5
lines changed
Expand file tree Collapse file tree 4 files changed +10
-5
lines changed Original file line number Diff line number Diff line change 11from typing import Callable , Tuple
22
33import aesara .tensor as at
4- import aesara .tensor .slinalg as slinalg
54from aesara .tensor .random .utils import RandomStream
65from aesara .tensor .shape import shape_tuple
6+ from aesara .tensor .slinalg import cholesky , solve_triangular
77from aesara .tensor .var import TensorVariable
88
99
@@ -51,9 +51,9 @@ def gaussian_metric(
5151 dot , matmul = at .dot , lambda x , y : x * y
5252 elif inverse_mass_matrix .ndim == 2 :
5353 shape = (shape_tuple (inverse_mass_matrix )[0 ],)
54- tril_inv = slinalg . cholesky (inverse_mass_matrix )
54+ tril_inv = cholesky (inverse_mass_matrix )
5555 identity = at .eye (* shape )
56- mass_matrix_sqrt = slinalg . solve_lower_triangular (tril_inv , identity )
56+ mass_matrix_sqrt = solve_triangular (tril_inv , identity , lower = True )
5757 dot , matmul = at .dot , at .dot
5858 else :
5959 raise ValueError (
Original file line number Diff line number Diff line change @@ -40,7 +40,7 @@ def update(initial_energy, state):
4040
4141 delta_energy = initial_energy - new_energy
4242 delta_energy = at .where (at .isnan (delta_energy ), - np .inf , delta_energy )
43- is_transition_divergent = at .abs_ (delta_energy ) > divergence_threshold
43+ is_transition_divergent = at .abs (delta_energy ) > divergence_threshold
4444
4545 weight = delta_energy
4646 log_p_accept = at .where (
Original file line number Diff line number Diff line change 55from aesara .graph .basic import Variable , ancestors
66from aesara .graph .fg import FunctionGraph
77from aesara .graph .rewriting .utils import rewrite_graph
8- from aesara .tensor .rewriting .shape import ShapeFeature
8+ from aesara .tensor .rewriting .basic import ShapeFeature
99from aesara .tensor .var import TensorVariable
1010
1111
Original file line number Diff line number Diff line change @@ -16,6 +16,11 @@ convention = numpy
1616[tool:pytest]
1717python_files =test*.py
1818testpaths =tests
19+ filterwarnings =
20+ error:::aesara
21+ error:::aeppl
22+ error:::aemcmc
23+ ignore:::xarray
1924
2025[coverage:run]
2126omit =
You can’t perform that action at this time.
0 commit comments