@@ -986,25 +986,28 @@ def constant(x, name=None, dtype=None) -> ScalarConstant:
986986
987987
988988def as_scalar (x : Any , name : str | None = None ) -> ScalarVariable :
989- from pytensor .tensor .basic import scalar_from_tensor
990- from pytensor .tensor .type import TensorType
989+ if isinstance (x , ScalarVariable ):
990+ return x
991+
992+ if isinstance (x , Variable ):
993+ from pytensor .tensor .basic import scalar_from_tensor
994+ from pytensor .tensor .type import TensorType
995+
996+ if isinstance (x .type , TensorType ) and x .type .ndim == 0 :
997+ return scalar_from_tensor (x )
998+ else :
999+ raise TypeError (f"Cannot convert { x } to a scalar type" )
9911000
9921001 if isinstance (x , Apply ):
1002+ # FIXME: Why do we support calling this with Apply?
1003+ # Also, if we do, why can't we support multiple outputs?
9931004 if len (x .outputs ) != 1 :
9941005 raise ValueError (
9951006 "It is ambiguous which output of a multi-output"
9961007 " Op has to be fetched." ,
9971008 x ,
9981009 )
999- else :
1000- x = x .outputs [0 ]
1001- if isinstance (x , Variable ):
1002- if isinstance (x , ScalarVariable ):
1003- return x
1004- elif isinstance (x .type , TensorType ) and x .type .ndim == 0 :
1005- return scalar_from_tensor (x )
1006- else :
1007- raise TypeError (f"Cannot convert { x } to a scalar type" )
1010+ return as_scalar (x .outputs [0 ])
10081011
10091012 return constant (x )
10101013
0 commit comments