Skip to content

Commit 6c446c3

Browse files
committed
Avoid double cloning in ScalarInnerGraph Ops
1 parent 987a3a7 commit 6c446c3

File tree

2 files changed

+5
-6
lines changed

2 files changed

+5
-6
lines changed

pytensor/scalar/basic.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@
1313
import builtins
1414
import math
1515
from collections.abc import Callable
16-
from copy import copy
1716
from itertools import chain
1817
from textwrap import dedent
1918
from typing import Any, TypeAlias
@@ -4093,12 +4092,12 @@ def __init__(self, *args, **kwargs):
40934092
self.prepare_node_called = set()
40944093
super().__init__(*args, **kwargs)
40954094

4096-
def _cleanup_graph(self, inputs, outputs):
4095+
def _cleanup_graph(self, inputs, outputs, clone: bool = True):
40974096
# TODO: We could convert to TensorVariable, optimize graph,
40984097
# and then convert back to ScalarVariable.
40994098
# This would introduce rewrites like `log(1 + x) -> log1p`.
41004099

4101-
fgraph = FunctionGraph(copy(inputs), copy(outputs))
4100+
fgraph = FunctionGraph(inputs, outputs, clone=clone)
41024101

41034102
# Validate node types
41044103
for node in fgraph.apply_nodes:
@@ -4321,7 +4320,7 @@ def __init__(self, inputs, outputs, name="Composite"):
43214320
assert res[0] != inputs
43224321
inputs, outputs = res[0], res2[1]
43234322

4324-
self.inputs, self.outputs = self._cleanup_graph(inputs, outputs)
4323+
self.inputs, self.outputs = self._cleanup_graph(inputs, outputs, clone=False)
43254324
self.inputs_type = tuple(input.type for input in self.inputs)
43264325
self.outputs_type = tuple(output.type for output in self.outputs)
43274326
self.nin = len(inputs)
@@ -4376,7 +4375,7 @@ def fgraph(self):
43764375
return self._fgraph
43774376
# fgraph cannot be a property of the base class because it messes up with C caching.
43784377
# We also need a `FunctionGraph(clone=True)` (default) according to an old comment
4379-
fgraph = FunctionGraph(self.inputs, self.outputs)
4378+
fgraph = FunctionGraph(self.inputs, self.outputs, clone=False)
43804379
self._fgraph = fgraph
43814380
return self._fgraph
43824381

pytensor/scalar/loop.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,7 @@ def __init__(
6767
inputs, outputs = clone([*init, *constant], update)
6868

6969
self.is_while = until is not None
70-
self.inputs, self.outputs = self._cleanup_graph(inputs, outputs)
70+
self.inputs, self.outputs = self._cleanup_graph(inputs, outputs, clone=False)
7171
self._validate_updates(self.inputs, self.outputs)
7272

7373
self.inputs_type = tuple(input.type for input in self.inputs)

0 commit comments

Comments (0)