@@ -4287,7 +4287,14 @@ class Composite(ScalarInnerGraphOp):
 
     init_param: tuple[str, ...] = ("inputs", "outputs")
 
-    def __init__(self, inputs, outputs, name="Composite", cleanup_graph: bool = True):
+    def __init__(
+        self,
+        inputs,
+        outputs,
+        name="Composite",
+        clone_graph: bool = True,
+        output_types_preference=None,
+    ):
         self.name = name
         self._name = None
         # We need to clone the graph as sometimes its nodes already
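
To make the new signature concrete, here is a minimal sketch of how the reworked constructor is meant to be called (hedged: it assumes `Composite` and the scalar helpers are reachable through `pytensor.scalar`, as in current releases; the `comp_fast` line simply mirrors what `make_new_inplace` does further down):

    import pytensor.scalar as ps

    x = ps.float64("x")
    y = ps.float64("y")
    out = ps.add(ps.mul(x, y), x)

    # Default: clone_graph=True, so the caller's variables are cloned
    # and left untouched.
    comp = ps.Composite([x, y], [out])

    # A caller that already owns a private clone of the graph can now
    # skip the redundant clone; this is the path make_new_inplace takes.
    comp_fast = ps.Composite(comp.inputs, comp.outputs, clone_graph=False)
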
@@ -4302,33 +4309,35 @@ def __init__(self, inputs, outputs, name="Composite", cleanup_graph: bool = True
         for i in inputs:
             assert i not in outputs  # This isn't supported, use identity
 
-        if len(outputs) > 1 or not any(
-            isinstance(var.owner.op, Composite) for var in outputs
-        ):
-            # No inner Composite
-            # FIXME: There could be a composite in the middle of the graph
-            inputs, outputs = clone(inputs, outputs)
-        else:
-            # Inner Composite that we need to flatten
-            assert len(outputs) == 1
-            # 1. Create a new graph from inputs up to the
-            # Composite
-            res = pytensor.compile.rebuild_collect_shared(
-                inputs=inputs, outputs=outputs[0].owner.inputs, copy_inputs_over=False
-            )  # Clone also the inputs
-            # 2. We continue this partial clone with the graph in
-            # the inner Composite
-            res2 = pytensor.compile.rebuild_collect_shared(
-                inputs=outputs[0].owner.op.inputs,
-                outputs=outputs[0].owner.op.outputs,
-                replace=dict(zip(outputs[0].owner.op.inputs, res[1], strict=True)),
-            )
-            assert len(res2[1]) == len(outputs)
-            assert len(res[0]) == len(inputs)
-            assert res[0] != inputs
-            inputs, outputs = res[0], res2[1]
+        if clone_graph:
+            if len(outputs) > 1 or not any(
+                isinstance(var.owner.op, Composite) for var in outputs
+            ):
+                # No inner Composite
+                # FIXME: There could be a composite in the middle of the graph
+                inputs, outputs = clone(inputs, outputs)
+            else:
+                # Inner Composite that we need to flatten
+                assert len(outputs) == 1
+                # 1. Create a new graph from inputs up to the
+                # Composite
+                res = pytensor.compile.rebuild_collect_shared(
+                    inputs=inputs,
+                    outputs=outputs[0].owner.inputs,
+                    copy_inputs_over=False,
+                )  # Clone also the inputs
+                # 2. We continue this partial clone with the graph in
+                # the inner Composite
+                res2 = pytensor.compile.rebuild_collect_shared(
+                    inputs=outputs[0].owner.op.inputs,
+                    outputs=outputs[0].owner.op.outputs,
+                    replace=dict(zip(outputs[0].owner.op.inputs, res[1], strict=True)),
+                )
+                assert len(res2[1]) == len(outputs)
+                assert len(res[0]) == len(inputs)
+                assert res[0] != inputs
+                inputs, outputs = res[0], res2[1]
 
-        if cleanup_graph:
         self.inputs, self.outputs = self._cleanup_graph(
             inputs, outputs, clone=False
         )
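
A hedged illustration of the flattening branch, with throwaway variable names (note the FIXME above: only a Composite that directly produces the single output is flattened, not one buried in the middle of the graph):

    # Continuing the sketch above.
    inner = ps.Composite([x, y], [ps.add(x, y)])

    a = ps.float64("a")
    b = ps.float64("b")
    nested = inner(a, b)  # single output, produced directly by a Composite

    # The two rebuild_collect_shared passes splice inner's graph into the
    # new Op, so `flat` holds one flat scalar graph instead of a
    # Composite nested inside a Composite.
    flat = ps.Composite([a, b], [nested])
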
@@ -4338,7 +4347,7 @@ def __init__(self, inputs, outputs, name="Composite", cleanup_graph: bool = True
         self.outputs_type = tuple(output.type for output in self.outputs)
         self.nin = len(inputs)
         self.nout = len(outputs)
-        super().__init__()
+        super().__init__(output_types_preference=output_types_preference)
 
     def __str__(self):
         if self._name is not None:
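
Forwarding `output_types_preference` to `super().__init__` matters because that callable is ScalarOp's hook for mapping input types to output types. A hedged one-liner, assuming `transfer_type` is exported from `pytensor.scalar` as in current releases (`transfer_type(0)` makes output 0 adopt input 0's type, the usual preference for inplace Ops):

    comp_pref = ps.Composite(
        [x, y], [out], output_types_preference=ps.transfer_type(0)
    )
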
@@ -4374,15 +4383,17 @@ def make_new_inplace(self, output_types_preference=None, name=None):
         This fct allow fix patch this.
 
         """
+
         d = {k: getattr(self, k) for k in self.init_param}
-        out = type(self)(**d, cleanup_graph=False)
-        if name:
-            out.name = name
-        else:
-            name = out.name
+        out = type(self)(
+            **d,
+            clone_graph=False,
+            output_types_preference=output_types_preference,
+            name=name or self.name,
+        )
+        # No need to recompute the _c_code and nodenames if they were already computed (which is true if the hash of the Op was requested)
         out._c_code = self._c_code
         out.nodenames = self.nodenames
-        super(Composite, out).__init__(output_types_preference, name)
         return out
 
     @property
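
Net effect of this last hunk: the inplace twin is produced by one ordinary constructor call, reusing the already-cloned self.inputs/self.outputs via clone_graph=False, instead of mutating a finished instance and re-running ScalarOp.__init__ on it. A hedged usage sketch (the name kwarg is just an illustrative convention, not the library's):

    inplace_comp = comp.make_new_inplace(
        output_types_preference=ps.transfer_type(0),
        name=f"{comp.name}_inplace",
    )

    # _c_code and nodenames are copied across, so C code that was already
    # generated (e.g. because the Op was hashed) is not regenerated.
    assert inplace_comp.nin == comp.nin and inplace_comp.nout == comp.nout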