@@ -987,25 +987,28 @@ def constant(x, name=None, dtype=None) -> ScalarConstant:
987987
988988
989989def as_scalar (x : Any , name : str | None = None ) -> ScalarVariable :
990- from pytensor .tensor .basic import scalar_from_tensor
991- from pytensor .tensor .type import TensorType
990+ if isinstance (x , ScalarVariable ):
991+ return x
992+
993+ if isinstance (x , Variable ):
994+ from pytensor .tensor .basic import scalar_from_tensor
995+ from pytensor .tensor .type import TensorType
996+
997+ if isinstance (x .type , TensorType ) and x .type .ndim == 0 :
998+ return scalar_from_tensor (x )
999+ else :
1000+ raise TypeError (f"Cannot convert { x } to a scalar type" )
9921001
9931002 if isinstance (x , Apply ):
1003+ # FIXME: Why do we support calling this with Apply?
1004+ # Also, if we do, why can't we support multiple outputs?
9941005 if len (x .outputs ) != 1 :
9951006 raise ValueError (
9961007 "It is ambiguous which output of a multi-output"
9971008 " Op has to be fetched." ,
9981009 x ,
9991010 )
1000- else :
1001- x = x .outputs [0 ]
1002- if isinstance (x , Variable ):
1003- if isinstance (x , ScalarVariable ):
1004- return x
1005- elif isinstance (x .type , TensorType ) and x .type .ndim == 0 :
1006- return scalar_from_tensor (x )
1007- else :
1008- raise TypeError (f"Cannot convert { x } to a scalar type" )
1011+ return as_scalar (x .outputs [0 ])
10091012
10101013 return constant (x )
10111014
0 commit comments