Commit 5a83a9f

Do not guess order of inputs for users
1 parent: 9709271 · commit: 5a83a9f
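Summary: helpers that used to infer a compiled function's inputs from the graph, and with them an input order, now require the caller to specify the inputs explicitly whenever the graph has more than one input. Internal call sites are updated to pass the model's `value_vars` directly.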

File tree (4 files changed: +42 −17 lines)

- pymc/model/core.py
- pymc/pytensorf.py
- pymc/tuning/starting.py
- tests/logprob/test_utils.py

pymc/model/core.py

Lines changed: 19 additions & 8 deletions

```diff
@@ -62,7 +62,6 @@
     convert_observed_data,
     gradient,
     hessian,
-    inputvars,
     join_nonshared_inputs,
     rewrite_pregrad,
 )
@@ -599,7 +598,11 @@ def compile_logp(
             Whether to sum all logp terms or return elemwise logp for each variable.
             Defaults to True.
         """
-        return self.compile_fn(self.logp(vars=vars, jacobian=jacobian, sum=sum), **compile_kwargs)
+        return self.compile_fn(
+            inputs=self.value_vars,
+            outs=self.logp(vars=vars, jacobian=jacobian, sum=sum),
+            **compile_kwargs,
+        )
 
     def compile_dlogp(
         self,
@@ -617,7 +620,11 @@ def compile_dlogp(
         jacobian : bool
             Whether to include jacobian terms in logprob graph. Defaults to True.
         """
-        return self.compile_fn(self.dlogp(vars=vars, jacobian=jacobian), **compile_kwargs)
+        return self.compile_fn(
+            inputs=self.value_vars,
+            outs=self.dlogp(vars=vars, jacobian=jacobian),
+            **compile_kwargs,
+        )
 
     def compile_d2logp(
         self,
@@ -637,7 +644,8 @@ def compile_d2logp(
             Whether to include jacobian terms in logprob graph. Defaults to True.
         """
         return self.compile_fn(
-            self.d2logp(vars=vars, jacobian=jacobian, negate_output=negate_output),
+            inputs=self.value_vars,
+            outs=self.d2logp(vars=vars, jacobian=jacobian, negate_output=negate_output),
             **compile_kwargs,
         )
 
@@ -742,7 +750,7 @@ def dlogp(
         dlogp graph
         """
         if vars is None:
-            value_vars = None
+            value_vars = self.continuous_value_vars
         else:
             if not isinstance(vars, list | tuple):
                 vars = [vars]
@@ -782,7 +790,7 @@ def d2logp(
         d²logp graph
        """
         if vars is None:
-            value_vars = None
+            value_vars = self.continuous_value_vars
         else:
             if not isinstance(vars, list | tuple):
                 vars = [vars]
@@ -1630,7 +1638,10 @@ def compile_fn(
         Compiled PyTensor function
         """
         if inputs is None:
-            inputs = inputvars(outs)
+            if len(inputs) > 1:
+                raise ValueError(
+                    "compile_fn requires inputs to be specified when there is more than one input."
+                )
 
         with self:
             fn = compile(
@@ -1793,7 +1804,7 @@ def point_logps(self, point=None, round_vals=2, **kwargs):
             factor.name: np.round(np.asarray(factor_logp), round_vals)
             for factor, factor_logp in zip(
                 factors,
-                self.compile_fn(factor_logps_fn, **kwargs)(point),
+                self.compile_fn(inputs=self.value_vars, outs=factor_logps_fn, **kwargs)(point),
             )
         }
```
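To illustrate the effect on the model API, here is a minimal usage sketch (not part of the commit; the model and data are made up for illustration). The compiled logp now closes over `model.value_vars` in a deterministic order and, as before, is called with a point dictionary:

```python
import numpy as np
import pymc as pm

with pm.Model() as model:
    mu = pm.Normal("mu", 0, 1)
    sigma = pm.HalfNormal("sigma", 1)
    pm.Normal("obs", mu, sigma, observed=np.array([0.1, -0.3, 0.2]))

# compile_logp now forwards inputs=self.value_vars to compile_fn,
# so the input order no longer depends on graph-traversal order.
logp_fn = model.compile_logp()

# The returned point function takes a {value-var name: value} dict,
# e.g. the model's initial point.
print(logp_fn(model.initial_point()))
```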

pymc/pytensorf.py

Lines changed: 19 additions & 5 deletions

```diff
@@ -35,7 +35,7 @@
 )
 from pytensor.graph.fg import FunctionGraph, Output
 from pytensor.graph.op import HasInnerGraph
-from pytensor.graph.traversal import graph_inputs, walk
+from pytensor.graph.traversal import explicit_graph_inputs, graph_inputs, walk
 from pytensor.scalar.basic import Cast
 from pytensor.scan.op import Scan
 from pytensor.tensor.basic import _as_tensor_variable
@@ -165,7 +165,7 @@ def extract_obs_data(x: TensorVariable) -> np.ndarray:
         mask[mask_idx] = 1
         return np.ma.MaskedArray(array_data, mask)
 
-    if not inputvars(x) and not rvs_in_graph(x):
+    if not any(explicit_graph_inputs(x)) and not rvs_in_graph(x):
         return x.eval(mode=_cheap_eval_mode)
 
     raise TypeError(f"Data cannot be extracted from {x}")
@@ -244,15 +244,17 @@ def cont_inputs(a):
     """
     Get the continuous inputs into PyTensor variables.
 
+    NOTE: No particular order is guaranteed across PyTensor versions
+
     Parameters
     ----------
     a: PyTensor variable
 
     Returns
     -------
-    r: list of tensor variables that are continuous inputs
+    r: list of tensor variables that are continuous inputs.
     """
-    return typefilter(inputvars(a), continuous_types)
+    return typefilter(explicit_graph_inputs(a), continuous_types)
 
 
 def floatX(X):
@@ -310,6 +312,10 @@ def gradient1(f, v):
 def gradient(f, vars=None):
     if vars is None:
         vars = cont_inputs(f)
+        if len(vars) > 1:
+            raise ValueError(
+                "gradient requires vars to be specified when there is more than one input."
+            )
 
     if vars:
         return pt.concatenate([gradient1(f, v) for v in vars], axis=0)
@@ -331,6 +337,10 @@ def grad_i(i):
 def jacobian(f, vars=None):
     if vars is None:
         vars = cont_inputs(f)
+        if len(vars) > 1:
+            raise ValueError(
+                "jacobian requires vars to be specified when there is more than one input."
+            )
 
     if vars:
         return pt.concatenate([jacobian1(f, v) for v in vars], axis=1)
@@ -378,6 +388,10 @@ def hess_ii(i):
 def hessian_diag(f, vars=None, negate_output=True):
     if vars is None:
         vars = cont_inputs(f)
+        if len(vars) > 1:
+            raise ValueError(
+                "hessian_diag requires vars to be specified when there is more than one input."
+            )
 
     if vars:
         res = pt.concatenate([hessian_diag1(f, v) for v in vars], axis=0)
@@ -611,7 +625,7 @@ def __call__(self, input):
         ----------
         input: TensorVariable
         """
-        (oldinput,) = inputvars(self.tensor)
+        (oldinput,) = explicit_graph_inputs(self.tensor)
         return pytensor.clone_replace(self.tensor, {oldinput: input}, rebuild_strict=False)
```
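A small sketch of the new guard in the gradient helpers (the variable names here are illustrative, not from the commit): with more than one input in the graph, `vars` must now be passed explicitly, which also fixes the order of the concatenated gradient:

```python
import pytensor.tensor as pt

from pymc.pytensorf import gradient

x = pt.vector("x")
y = pt.vector("y")
cost = (x**2).sum() + (y**2).sum()

# gradient(cost) would now raise ValueError: with two inputs there is
# no well-defined order to concatenate the partial gradients in.
g = gradient(cost, vars=[x, y])  # gradient of cost w.r.t. x, then y
```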

pymc/tuning/starting.py

Lines changed: 1 addition & 1 deletion

```diff
@@ -185,7 +185,7 @@ def find_MAP(
 
     mx0 = RaveledVars(mx0, x0.point_map_info)
     unobserved_vars = get_default_varnames(model.unobserved_value_vars, include_transformed)
-    unobserved_vars_values = model.compile_fn(unobserved_vars)(
+    unobserved_vars_values = model.compile_fn(inputs=model.value_vars, outs=unobserved_vars)(
         DictToArrayBijection.rmap(mx0, start)
     )
     mx = {var.name: value for var, value in zip(unobserved_vars, unobserved_vars_values)}
```
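The same pattern, sketched generically (hypothetical one-variable model, not from the commit): compile over an explicit input list, then call the resulting point function with a name-to-value dictionary:

```python
import pymc as pm

with pm.Model() as model:
    pm.Normal("mu", 0, 1)

# Pinning inputs to model.value_vars makes the argument order explicit.
fn = model.compile_fn(inputs=model.value_vars, outs=model.unobserved_value_vars)
print(fn(model.initial_point()))
```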

tests/logprob/test_utils.py

Lines changed: 3 additions & 3 deletions

```diff
@@ -42,13 +42,13 @@
 from pytensor import tensor as pt
 from pytensor.compile import get_default_mode
 from pytensor.graph.basic import equal_computations
-from pytensor.graph.traversal import ancestors
+from pytensor.graph.traversal import ancestors, explicit_graph_inputs
 from pytensor.tensor.random.basic import NormalRV
 from pytensor.tensor.random.op import RandomVariable
 
 import pymc as pm
 
-from pymc import SymbolicRandomVariable, inputvars
+from pymc.distributions.distribution import SymbolicRandomVariable
 from pymc.distributions.transforms import Interval
 from pymc.logprob.abstract import MeasurableOp, valued_rv
 from pymc.logprob.basic import logp
@@ -231,7 +231,7 @@ def test_interdependent_transformed_rvs(self, reversed):
 
         assert_no_rvs(transform_values)
         # Test that we haven't introduced value variables in the random graph (issue #7054)
-        assert not inputvars(rvs)
+        assert not any(explicit_graph_inputs(rvs))
 
         if reversed:
             transform_values = transform_values[::-1]
```
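For reference, a minimal sketch of the `explicit_graph_inputs` idiom used above (assuming a PyTensor version that ships `pytensor.graph.traversal`): it yields only the graph's true root variables, excluding constants and shared variables, as a lazy generator, hence the `any(...)` wrapper in the test:

```python
import pytensor
import pytensor.tensor as pt

from pytensor.graph.traversal import explicit_graph_inputs

x = pt.scalar("x")
shared = pytensor.shared(2.0, name="s")
out = x * shared + 1.0

# Constants and shared variables are excluded; only `x` remains.
print(list(explicit_graph_inputs([out])))  # [x]
```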
