1313import builtins
1414import math
1515from collections .abc import Callable
16- from copy import copy
1716from itertools import chain
1817from textwrap import dedent
1918from typing import Any , TypeAlias
@@ -4093,12 +4092,12 @@ def __init__(self, *args, **kwargs):
40934092 self .prepare_node_called = set ()
40944093 super ().__init__ (* args , ** kwargs )
40954094
4096- def _cleanup_graph (self , inputs , outputs ):
4095+ def _cleanup_graph (self , inputs , outputs , clone : builtins . bool = True ):
40974096 # TODO: We could convert to TensorVariable, optimize graph,
40984097 # and then convert back to ScalarVariable.
40994098 # This would introduce rewrites like `log(1 + x) -> log1p`.
41004099
4101- fgraph = FunctionGraph (copy ( inputs ), copy ( outputs ) )
4100+ fgraph = FunctionGraph (inputs , outputs , clone = clone )
41024101
41034102 # Validate node types
41044103 for node in fgraph .apply_nodes :
@@ -4281,7 +4280,9 @@ class Composite(ScalarInnerGraphOp):
42814280
42824281 init_param : tuple [str , ...] = ("inputs" , "outputs" )
42834282
4284- def __init__ (self , inputs , outputs , name = "Composite" ):
4283+ def __init__ (
4284+ self , inputs , outputs , name = "Composite" , clone_graph : builtins .bool = True
4285+ ):
42854286 self .name = name
42864287 self ._name = None
42874288 # We need to clone the graph as sometimes its nodes already
@@ -4299,10 +4300,13 @@ def __init__(self, inputs, outputs, name="Composite"):
42994300 if len (outputs ) > 1 or not any (
43004301 isinstance (var .owner .op , Composite ) for var in outputs
43014302 ):
4302- # No inner Composite
4303- inputs , outputs = clone (inputs , outputs )
4303+ if clone_graph :
4304+ inputs , outputs = clone (inputs , outputs )
4305+
43044306 else :
43054307 # Inner Composite that we need to flatten
4308+ # FIXME: There could be a composite in the middle of the graph, why is this here?
4309+ # If anything it should be an optimization, but I suspect lower-level compilation can handle this anyway.
43064310 assert len (outputs ) == 1
43074311 # 1. Create a new graph from inputs up to the
43084312 # Composite
@@ -4321,7 +4325,8 @@ def __init__(self, inputs, outputs, name="Composite"):
43214325 assert res [0 ] != inputs
43224326 inputs , outputs = res [0 ], res2 [1 ]
43234327
4324- self .inputs , self .outputs = self ._cleanup_graph (inputs , outputs )
4328+ # We already cloned the graph, or the user told us there was no need for it
4329+ self .inputs , self .outputs = self ._cleanup_graph (inputs , outputs , clone = False )
43254330 self .inputs_type = tuple (input .type for input in self .inputs )
43264331 self .outputs_type = tuple (output .type for output in self .outputs )
43274332 self .nin = len (inputs )
0 commit comments