Skip to content

Commit dcc624a

Browse files
Update PyTensor dependency
Co-authored-by: Jesse Grabowski <48652735+jessegrabowski@users.noreply.github.com>
1 parent 931f89f commit dcc624a

26 files changed

+41
-38
lines changed

conda-envs/environment-alternative-backends.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ dependencies:
2222
- numpyro>=0.8.0
2323
- pandas>=0.24.0
2424
- pip
25-
- pytensor>=2.32.0,<2.33
25+
- pytensor>=2.34.0,<2.35
2626
- python-graphviz
2727
- networkx
2828
- rich>=13.7.1

conda-envs/environment-dev.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ dependencies:
1212
- numpy>=1.25.0
1313
- pandas>=0.24.0
1414
- pip
15-
- pytensor>=2.32.0,<2.33
15+
- pytensor>=2.34.0,<2.35
1616
- python-graphviz
1717
- networkx
1818
- scipy>=1.4.1

conda-envs/environment-docs.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ dependencies:
1111
- numpy>=1.25.0
1212
- pandas>=0.24.0
1313
- pip
14-
- pytensor>=2.32.0,<2.33
14+
- pytensor>=2.34.0,<2.35
1515
- python-graphviz
1616
- rich>=13.7.1
1717
- scipy>=1.4.1

conda-envs/environment-test.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ dependencies:
1414
- pandas>=0.24.0
1515
- pip
1616
- polyagamma
17-
- pytensor>=2.32.0,<2.33
17+
- pytensor>=2.34.0,<2.35
1818
- python-graphviz
1919
- networkx
2020
- rich>=13.7.1

conda-envs/windows-environment-dev.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ dependencies:
1212
- numpy>=1.25.0
1313
- pandas>=0.24.0
1414
- pip
15-
- pytensor>=2.32.0,<2.33
15+
- pytensor>=2.34.0,<2.35
1616
- python-graphviz
1717
- networkx
1818
- rich>=13.7.1

conda-envs/windows-environment-test.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ dependencies:
1515
- pandas>=0.24.0
1616
- pip
1717
- polyagamma
18-
- pytensor>=2.32.0,<2.33
18+
- pytensor>=2.34.0,<2.35
1919
- python-graphviz
2020
- networkx
2121
- rich>=13.7.1

pymc/distributions/custom.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,9 +18,9 @@
1818

1919
from pytensor import Variable, clone_replace
2020
from pytensor import tensor as pt
21-
from pytensor.graph.basic import io_toposort
2221
from pytensor.graph.features import ReplaceValidate
2322
from pytensor.graph.rewriting.basic import GraphRewriter
23+
from pytensor.graph.traversal import io_toposort
2424
from pytensor.scan.op import Scan
2525
from pytensor.tensor import TensorVariable, as_tensor_variable
2626
from pytensor.tensor.random.op import RandomVariable

pymc/distributions/timeseries.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,9 @@
2121
import pytensor
2222
import pytensor.tensor as pt
2323

24-
from pytensor.graph.basic import Apply, ancestors
24+
from pytensor.graph.basic import Apply
2525
from pytensor.graph.replace import clone_replace
26+
from pytensor.graph.traversal import ancestors
2627
from pytensor.tensor import TensorVariable
2728
from pytensor.tensor.random.op import RandomVariable
2829
from pytensor.tensor.random.utils import normalize_size_param

pymc/initial_point.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,9 +22,10 @@
2222

2323
from pytensor import graph_replace
2424
from pytensor.compile.ops import TypeCastingOp
25-
from pytensor.graph.basic import Apply, Variable, ancestors, walk
25+
from pytensor.graph.basic import Apply, Variable
2626
from pytensor.graph.fg import FunctionGraph
2727
from pytensor.graph.rewriting.db import RewriteDatabaseQuery, SequenceDB
28+
from pytensor.graph.traversal import ancestors, walk
2829
from pytensor.tensor.variable import TensorVariable
2930

3031
from pymc.logprob.transforms import Transform

pymc/logprob/basic.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -45,10 +45,9 @@
4545
from pytensor.graph.basic import (
4646
Constant,
4747
Variable,
48-
ancestors,
49-
walk,
5048
)
5149
from pytensor.graph.rewriting.basic import GraphRewriter, NodeRewriter
50+
from pytensor.graph.traversal import ancestors, walk
5251
from pytensor.tensor.variable import TensorVariable
5352

5453
from pymc.logprob.abstract import (
@@ -533,8 +532,8 @@ def conditional_logp(
533532
f"The logprob terms of the following value variables could not be derived: {missing_value_terms}"
534533
)
535534

536-
values, logprobs = zip(*values_to_logprobs.items())
537-
logprobs = cleanup_ir(logprobs)
535+
# Ensure same order as input
536+
logprobs = cleanup_ir(tuple(values_to_logprobs[v] for v in original_values))
538537

539538
if warn_rvs:
540539
rvs_in_logp_expressions = _find_unallowed_rvs_in_graph(logprobs)
@@ -545,7 +544,7 @@ def conditional_logp(
545544
UserWarning,
546545
)
547546

548-
return dict(zip(values, logprobs))
547+
return dict(zip(original_values, logprobs))
549548

550549

551550
def transformed_conditional_logp(

0 commit comments

Comments
 (0)