1313import builtins
1414import math
1515from collections .abc import Callable
16- from copy import copy
1716from itertools import chain
1817from textwrap import dedent
1918from typing import Any , TypeAlias
@@ -4094,12 +4093,12 @@ def __init__(self, *args, **kwargs):
40944093 self .prepare_node_called = set ()
40954094 super ().__init__ (* args , ** kwargs )
40964095
4097- def _cleanup_graph (self , inputs , outputs ):
4096+ def _cleanup_graph (self , inputs , outputs , clone : builtins . bool = True ):
40984097 # TODO: We could convert to TensorVariable, optimize graph,
40994098 # and then convert back to ScalarVariable.
41004099 # This would introduce rewrites like `log(1 + x) -> log1p`.
41014100
4102- fgraph = FunctionGraph (copy ( inputs ), copy ( outputs ) )
4101+ fgraph = FunctionGraph (inputs , outputs , clone = clone )
41034102
41044103 # Validate node types
41054104 for node in fgraph .apply_nodes :
@@ -4282,7 +4281,9 @@ class Composite(ScalarInnerGraphOp):
42824281
42834282 init_param : tuple [str , ...] = ("inputs" , "outputs" )
42844283
4285- def __init__ (self , inputs , outputs , name = "Composite" ):
4284+ def __init__ (
4285+ self , inputs , outputs , name = "Composite" , clone_graph : builtins .bool = True
4286+ ):
42864287 self .name = name
42874288 self ._name = None
42884289 # We need to clone the graph as sometimes its nodes already
@@ -4300,10 +4301,13 @@ def __init__(self, inputs, outputs, name="Composite"):
43004301 if len (outputs ) > 1 or not any (
43014302 isinstance (var .owner .op , Composite ) for var in outputs
43024303 ):
4303- # No inner Composite
4304- inputs , outputs = clone (inputs , outputs )
4304+ if clone_graph :
4305+ inputs , outputs = clone (inputs , outputs )
4306+
43054307 else :
43064308 # Inner Composite that we need to flatten
4309+ # FIXME: There could be a composite in the middle of the graph, why is this here?
4310+ # If anything it should be an optimization, but I suspect lower-level compilation can handle this anyway.
43074311 assert len (outputs ) == 1
43084312 # 1. Create a new graph from inputs up to the
43094313 # Composite
@@ -4322,7 +4326,8 @@ def __init__(self, inputs, outputs, name="Composite"):
43224326 assert res [0 ] != inputs
43234327 inputs , outputs = res [0 ], res2 [1 ]
43244328
4325- self .inputs , self .outputs = self ._cleanup_graph (inputs , outputs )
4329+ # We already cloned the graph, or the user told us there was no need for it
4330+ self .inputs , self .outputs = self ._cleanup_graph (inputs , outputs , clone = False )
43264331 self .inputs_type = tuple (input .type for input in self .inputs )
43274332 self .outputs_type = tuple (output .type for output in self .outputs )
43284333 self .nin = len (inputs )
0 commit comments