|
13 | 13 | import builtins |
14 | 14 | import math |
15 | 15 | from collections.abc import Callable |
16 | | -from copy import copy |
17 | 16 | from itertools import chain |
18 | 17 | from textwrap import dedent |
19 | 18 | from typing import Any, TypeAlias |
@@ -4093,12 +4092,12 @@ def __init__(self, *args, **kwargs): |
4093 | 4092 | self.prepare_node_called = set() |
4094 | 4093 | super().__init__(*args, **kwargs) |
4095 | 4094 |
|
4096 | | - def _cleanup_graph(self, inputs, outputs): |
| 4095 | + def _cleanup_graph(self, inputs, outputs, clone: bool = True): |
4097 | 4096 | # TODO: We could convert to TensorVariable, optimize graph, |
4098 | 4097 | # and then convert back to ScalarVariable. |
4099 | 4098 | # This would introduce rewrites like `log(1 + x) -> log1p`. |
4100 | 4099 |
|
4101 | | - fgraph = FunctionGraph(copy(inputs), copy(outputs)) |
| 4100 | + fgraph = FunctionGraph(inputs, outputs, clone=clone) |
4102 | 4101 |
|
4103 | 4102 | # Validate node types |
4104 | 4103 | for node in fgraph.apply_nodes: |
@@ -4321,7 +4320,7 @@ def __init__(self, inputs, outputs, name="Composite"): |
4321 | 4320 | assert res[0] != inputs |
4322 | 4321 | inputs, outputs = res[0], res2[1] |
4323 | 4322 |
|
4324 | | - self.inputs, self.outputs = self._cleanup_graph(inputs, outputs) |
| 4323 | + self.inputs, self.outputs = self._cleanup_graph(inputs, outputs, clone=False) |
4325 | 4324 | self.inputs_type = tuple(input.type for input in self.inputs) |
4326 | 4325 | self.outputs_type = tuple(output.type for output in self.outputs) |
4327 | 4326 | self.nin = len(inputs) |
@@ -4376,7 +4375,7 @@ def fgraph(self): |
4376 | 4375 | return self._fgraph |
4377 | 4376 | # fgraph cannot be a property of the base class because it messes up with C caching. |
4378 | 4377 | # We also need a `FunctionGraph(clone=True)` (default) according to an old comment |
4379 | | - fgraph = FunctionGraph(self.inputs, self.outputs) |
| 4378 | + fgraph = FunctionGraph(self.inputs, self.outputs, clone=False) |
4380 | 4379 | self._fgraph = fgraph |
4381 | 4380 | return self._fgraph |
4382 | 4381 |
|
|
0 commit comments