1313import builtins
1414import math
1515from collections .abc import Callable
16- from copy import copy
1716from itertools import chain
1817from textwrap import dedent
1918from typing import Any , TypeAlias
@@ -4093,12 +4092,12 @@ def __init__(self, *args, **kwargs):
40934092 self .prepare_node_called = set ()
40944093 super ().__init__ (* args , ** kwargs )
40954094
4096- def _cleanup_graph (self , inputs , outputs ):
4095+ def _cleanup_graph (self , inputs , outputs , clone : bool = True ):
40974096 # TODO: We could convert to TensorVariable, optimize graph,
40984097 # and then convert back to ScalarVariable.
40994098 # This would introduce rewrites like `log(1 + x) -> log1p`.
41004099
4101- fgraph = FunctionGraph (copy ( inputs ), copy ( outputs ) )
4100+ fgraph = FunctionGraph (inputs , outputs , clone = clone )
41024101
41034102 # Validate node types
41044103 for node in fgraph .apply_nodes :
@@ -4281,7 +4280,7 @@ class Composite(ScalarInnerGraphOp):
42814280
42824281 init_param : tuple [str , ...] = ("inputs" , "outputs" )
42834282
4284- def __init__ (self , inputs , outputs , name = "Composite" ):
4283+ def __init__ (self , inputs , outputs , name = "Composite" , clone_graph : bool = True ):
42854284 self .name = name
42864285 self ._name = None
42874286 # We need to clone the graph as sometimes its nodes already
@@ -4299,10 +4298,13 @@ def __init__(self, inputs, outputs, name="Composite"):
42994298 if len (outputs ) > 1 or not any (
43004299 isinstance (var .owner .op , Composite ) for var in outputs
43014300 ):
4302- # No inner Composite
4303- inputs , outputs = clone (inputs , outputs )
4301+ if clone_graph :
4302+ inputs , outputs = clone (inputs , outputs )
4303+
43044304 else :
43054305 # Inner Composite that we need to flatten
4306+ # FIXME: There could be a composite in the middle of the graph, why is this here?
4307+ # If anything it should be an optimization, but I suspect lower-level compilation can handle this anyway.
43064308 assert len (outputs ) == 1
43074309 # 1. Create a new graph from inputs up to the
43084310 # Composite
@@ -4321,7 +4323,8 @@ def __init__(self, inputs, outputs, name="Composite"):
43214323 assert res [0 ] != inputs
43224324 inputs , outputs = res [0 ], res2 [1 ]
43234325
4324- self .inputs , self .outputs = self ._cleanup_graph (inputs , outputs )
4326+ # We already cloned the graph, or the user told us there was no need for it
4327+ self .inputs , self .outputs = self ._cleanup_graph (inputs , outputs , clone = False )
43254328 self .inputs_type = tuple (input .type for input in self .inputs )
43264329 self .outputs_type = tuple (output .type for output in self .outputs )
43274330 self .nin = len (inputs )
0 commit comments