Skip to content

Commit 60573b9

Browse files
committed
.fix previous composite changes
1 parent d4ab146 commit 60573b9

File tree

1 file changed

+17
-11
lines changed

1 file changed

+17
-11
lines changed

pytensor/scalar/basic.py

Lines changed: 17 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -988,8 +988,17 @@ def constant(x, name=None, dtype=None) -> ScalarConstant:
988988

989989

990990
def as_scalar(x: Any, name: str | None = None) -> ScalarVariable:
991-
from pytensor.tensor.basic import scalar_from_tensor
992-
from pytensor.tensor.type import TensorType
991+
if isinstance(x, ScalarVariable):
992+
return x
993+
994+
if isinstance(x, Variable):
995+
from pytensor.tensor.basic import scalar_from_tensor
996+
from pytensor.tensor.type import TensorType
997+
998+
if isinstance(x.type, TensorType) and x.type.ndim == 0:
999+
return scalar_from_tensor(x)
1000+
else:
1001+
raise TypeError(f"Cannot convert {x} to a scalar type")
9931002

9941003
if isinstance(x, Apply):
9951004
if len(x.outputs) != 1:
@@ -999,14 +1008,7 @@ def as_scalar(x: Any, name: str | None = None) -> ScalarVariable:
9991008
x,
10001009
)
10011010
else:
1002-
x = x.outputs[0]
1003-
if isinstance(x, Variable):
1004-
if isinstance(x, ScalarVariable):
1005-
return x
1006-
elif isinstance(x.type, TensorType) and x.type.ndim == 0:
1007-
return scalar_from_tensor(x)
1008-
else:
1009-
raise TypeError(f"Cannot convert {x} to a scalar type")
1011+
return as_scalar(x.outputs[0])
10101012

10111013
return constant(x)
10121014

@@ -1238,7 +1240,10 @@ def make_node(self, *inputs):
12381240
f"Wrong number of inputs for {self}.make_node "
12391241
f"(got {len(inputs)}({inputs}), expected {self.nin})"
12401242
)
1241-
inputs = [as_scalar(input) for input in inputs]
1243+
inputs = [
1244+
inp if isinstance(inp, ScalarVariable) else as_scalar(input)
1245+
for inp in inputs
1246+
]
12421247
outputs = [t() for t in self.output_types([input.type for input in inputs])]
12431248
if len(outputs) != self.nout:
12441249
inputs_str = (", ".join(str(input) for input in inputs),)
@@ -4376,6 +4381,7 @@ def make_new_inplace(self, output_types_preference=None, name=None):
43764381
else:
43774382
name = out.name
43784383
out._c_code = self._c_code
4384+
out.nodenames = self.nodenames
43794385
super(Composite, out).__init__(output_types_preference, name)
43804386
return out
43814387

0 commit comments

Comments
 (0)