Skip to content

Commit 0b2cf52

Browse files
committed
Short-circuit as_scalar common cases faster
1 parent 5f8cee6 commit 0b2cf52

File tree

1 file changed

+14
-11
lines changed

1 file changed

+14
-11
lines changed

pytensor/scalar/basic.py

Lines changed: 14 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -987,25 +987,28 @@ def constant(x, name=None, dtype=None) -> ScalarConstant:
987987

988988

989989
def as_scalar(x: Any, name: str | None = None) -> ScalarVariable:
990-
from pytensor.tensor.basic import scalar_from_tensor
991-
from pytensor.tensor.type import TensorType
990+
if isinstance(x, ScalarVariable):
991+
return x
992+
993+
if isinstance(x, Variable):
994+
from pytensor.tensor.basic import scalar_from_tensor
995+
from pytensor.tensor.type import TensorType
996+
997+
if isinstance(x.type, TensorType) and x.type.ndim == 0:
998+
return scalar_from_tensor(x)
999+
else:
1000+
raise TypeError(f"Cannot convert {x} to a scalar type")
9921001

9931002
if isinstance(x, Apply):
1003+
# FIXME: Why do we support calling this with Apply?
1004+
# Also, if we do, why can't we support multiple outputs?
9941005
if len(x.outputs) != 1:
9951006
raise ValueError(
9961007
"It is ambiguous which output of a multi-output"
9971008
" Op has to be fetched.",
9981009
x,
9991010
)
1000-
else:
1001-
x = x.outputs[0]
1002-
if isinstance(x, Variable):
1003-
if isinstance(x, ScalarVariable):
1004-
return x
1005-
elif isinstance(x.type, TensorType) and x.type.ndim == 0:
1006-
return scalar_from_tensor(x)
1007-
else:
1008-
raise TypeError(f"Cannot convert {x} to a scalar type")
1011+
return as_scalar(x.outputs[0])
10091012

10101013
return constant(x)
10111014

0 commit comments

Comments
 (0)