@@ -988,8 +988,17 @@ def constant(x, name=None, dtype=None) -> ScalarConstant:
988988
989989
990990def as_scalar (x : Any , name : str | None = None ) -> ScalarVariable :
991- from pytensor .tensor .basic import scalar_from_tensor
992- from pytensor .tensor .type import TensorType
991+ if isinstance (x , ScalarVariable ):
992+ return x
993+
994+ if isinstance (x , Variable ):
995+ from pytensor .tensor .basic import scalar_from_tensor
996+ from pytensor .tensor .type import TensorType
997+
998+ if isinstance (x .type , TensorType ) and x .type .ndim == 0 :
999+ return scalar_from_tensor (x )
1000+ else :
1001+ raise TypeError (f"Cannot convert { x } to a scalar type" )
9931002
9941003 if isinstance (x , Apply ):
9951004 if len (x .outputs ) != 1 :
@@ -999,14 +1008,7 @@ def as_scalar(x: Any, name: str | None = None) -> ScalarVariable:
9991008 x ,
10001009 )
10011010 else :
1002- x = x .outputs [0 ]
1003- if isinstance (x , Variable ):
1004- if isinstance (x , ScalarVariable ):
1005- return x
1006- elif isinstance (x .type , TensorType ) and x .type .ndim == 0 :
1007- return scalar_from_tensor (x )
1008- else :
1009- raise TypeError (f"Cannot convert { x } to a scalar type" )
1011+ return as_scalar (x .outputs [0 ])
10101012
10111013 return constant (x )
10121014
@@ -1238,7 +1240,10 @@ def make_node(self, *inputs):
12381240 f"Wrong number of inputs for { self } .make_node "
12391241 f"(got { len (inputs )} ({ inputs } ), expected { self .nin } )"
12401242 )
1241- inputs = [as_scalar (input ) for input in inputs ]
1243+ inputs = [
1244+ inp if isinstance (inp , ScalarVariable ) else as_scalar (input )
1245+ for inp in inputs
1246+ ]
12421247 outputs = [t () for t in self .output_types ([input .type for input in inputs ])]
12431248 if len (outputs ) != self .nout :
12441249 inputs_str = (", " .join (str (input ) for input in inputs ),)
@@ -4376,6 +4381,7 @@ def make_new_inplace(self, output_types_preference=None, name=None):
43764381 else :
43774382 name = out .name
43784383 out ._c_code = self ._c_code
4384+ out .nodenames = self .nodenames
43794385 super (Composite , out ).__init__ (output_types_preference , name )
43804386 return out
43814387
0 commit comments