@@ -4280,7 +4280,7 @@ class Composite(ScalarInnerGraphOp):
42804280
42814281 init_param : tuple [str , ...] = ("inputs" , "outputs" )
42824282
4283- def __init__ (self , inputs , outputs , name = "Composite" ):
4283+ def __init__ (self , inputs , outputs , name = "Composite" , cleanup_graph : bool = True ):
42844284 self .name = name
42854285 self ._name = None
42864286 # We need to clone the graph as sometimes its nodes already
@@ -4299,6 +4299,7 @@ def __init__(self, inputs, outputs, name="Composite"):
42994299 isinstance (var .owner .op , Composite ) for var in outputs
43004300 ):
43014301 # No inner Composite
4302+ # FIXME: There could be a composite in the middle of the graph
43024303 inputs , outputs = clone (inputs , outputs )
43034304 else :
43044305 # Inner Composite that we need to flatten
@@ -4320,7 +4321,12 @@ def __init__(self, inputs, outputs, name="Composite"):
43204321 assert res [0 ] != inputs
43214322 inputs , outputs = res [0 ], res2 [1 ]
43224323
4323- self .inputs , self .outputs = self ._cleanup_graph (inputs , outputs , clone = False )
4324+ if cleanup_graph :
4325+ self .inputs , self .outputs = self ._cleanup_graph (
4326+ inputs , outputs , clone = False
4327+ )
4328+ else :
4329+ self .inputs , self .outputs = inputs , outputs
43244330 self .inputs_type = tuple (input .type for input in self .inputs )
43254331 self .outputs_type = tuple (output .type for output in self .outputs )
43264332 self .nin = len (inputs )
@@ -4362,11 +4368,12 @@ def make_new_inplace(self, output_types_preference=None, name=None):
43624368
43634369 """
43644370 d = {k : getattr (self , k ) for k in self .init_param }
4365- out = self . __class__ (** d )
4371+ out = type ( self ) (** d , cleanup_graph = False )
43664372 if name :
43674373 out .name = name
43684374 else :
43694375 name = out .name
4376+ out ._c_code = self ._c_code
43704377 super (Composite , out ).__init__ (output_types_preference , name )
43714378 return out
43724379
0 commit comments