Skip to content

Commit 87e84af

Browse files
committed
Short-circuit as_scalar common cases faster
1 parent 2f499c4 commit 87e84af

File tree

1 file changed

+14
-11
lines changed

1 file changed

+14
-11
lines changed

pytensor/scalar/basic.py

Lines changed: 14 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -986,25 +986,28 @@ def constant(x, name=None, dtype=None) -> ScalarConstant:
986986

987987

988988
def as_scalar(x: Any, name: str | None = None) -> ScalarVariable:
989-
from pytensor.tensor.basic import scalar_from_tensor
990-
from pytensor.tensor.type import TensorType
989+
if isinstance(x, ScalarVariable):
990+
return x
991+
992+
if isinstance(x, Variable):
993+
from pytensor.tensor.basic import scalar_from_tensor
994+
from pytensor.tensor.type import TensorType
995+
996+
if isinstance(x.type, TensorType) and x.type.ndim == 0:
997+
return scalar_from_tensor(x)
998+
else:
999+
raise TypeError(f"Cannot convert {x} to a scalar type")
9911000

9921001
if isinstance(x, Apply):
1002+
# FIXME: Why do we support calling this with Apply?
1003+
# Also, if we do, why can't we support multiple outputs?
9931004
if len(x.outputs) != 1:
9941005
raise ValueError(
9951006
"It is ambiguous which output of a multi-output"
9961007
" Op has to be fetched.",
9971008
x,
9981009
)
999-
else:
1000-
x = x.outputs[0]
1001-
if isinstance(x, Variable):
1002-
if isinstance(x, ScalarVariable):
1003-
return x
1004-
elif isinstance(x.type, TensorType) and x.type.ndim == 0:
1005-
return scalar_from_tensor(x)
1006-
else:
1007-
raise TypeError(f"Cannot convert {x} to a scalar type")
1010+
return as_scalar(x.outputs[0])
10081011

10091012
return constant(x)
10101013

0 commit comments

Comments
 (0)