Commit fa3984f

Avoid double cloning of Composite Ops created by FusionOptimizer
1 parent 7d39946 commit fa3984f
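
The change itself: `Composite.__init__` used to clone the inner scalar graph unconditionally, and `_cleanup_graph` then built a `FunctionGraph` that cloned it a second time. The new `clone_graph` flag lets a caller that already owns a freshly built scalar graph (such as the FusionOptimizer, whose `elemwise_to_scalar` creates fresh variables) skip the redundant work. A minimal sketch of the intended usage, assuming the usual `pytensor.scalar` building blocks; the variable names are illustrative only:

    import pytensor.scalar as ps

    # Freshly built scalar graph that is not referenced anywhere else
    x = ps.float64("x")
    y = ps.float64("y")
    out = ps.exp(x + y)

    # The caller vouches that the graph is fresh, so Composite skips its defensive clone
    comp = ps.Composite([x, y], [out], clone_graph=False)

    # Default behaviour is unchanged: the inner graph is still cloned once
    comp_default = ps.Composite([x, y], [out])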

File tree (2 files changed, +19 −13 lines):
  pytensor/scalar/basic.py
  pytensor/tensor/rewriting/elemwise.py

2 files changed

+19
-13
lines changed

pytensor/scalar/basic.py

Lines changed: 12 additions & 7 deletions
@@ -13,7 +13,6 @@
 import builtins
 import math
 from collections.abc import Callable
-from copy import copy
 from itertools import chain
 from textwrap import dedent
 from typing import Any, TypeAlias
@@ -4093,12 +4092,12 @@ def __init__(self, *args, **kwargs):
         self.prepare_node_called = set()
         super().__init__(*args, **kwargs)
 
-    def _cleanup_graph(self, inputs, outputs):
+    def _cleanup_graph(self, inputs, outputs, clone: builtins.bool = True):
         # TODO: We could convert to TensorVariable, optimize graph,
         # and then convert back to ScalarVariable.
         # This would introduce rewrites like `log(1 + x) -> log1p`.
 
-        fgraph = FunctionGraph(copy(inputs), copy(outputs))
+        fgraph = FunctionGraph(inputs, outputs, clone=clone)
 
         # Validate node types
         for node in fgraph.apply_nodes:
@@ -4281,7 +4280,9 @@ class Composite(ScalarInnerGraphOp):
 
     init_param: tuple[str, ...] = ("inputs", "outputs")
 
-    def __init__(self, inputs, outputs, name="Composite"):
+    def __init__(
+        self, inputs, outputs, name="Composite", clone_graph: builtins.bool = True
+    ):
         self.name = name
         self._name = None
         # We need to clone the graph as sometimes its nodes already
@@ -4299,10 +4300,13 @@ def __init__(self, inputs, outputs, name="Composite"):
         if len(outputs) > 1 or not any(
             isinstance(var.owner.op, Composite) for var in outputs
         ):
-            # No inner Composite
-            inputs, outputs = clone(inputs, outputs)
+            if clone_graph:
+                inputs, outputs = clone(inputs, outputs)
+
         else:
             # Inner Composite that we need to flatten
+            # FIXME: There could be a composite in the middle of the graph, why is this here?
+            # If anything it should be an optimization, but I suspect lower-level compilation can handle this anyway.
             assert len(outputs) == 1
             # 1. Create a new graph from inputs up to the
             # Composite
@@ -4321,7 +4325,8 @@ def __init__(self, inputs, outputs, name="Composite"):
             assert res[0] != inputs
             inputs, outputs = res[0], res2[1]
 
-        self.inputs, self.outputs = self._cleanup_graph(inputs, outputs)
+        # We already cloned the graph, or the user told us there was no need for it
+        self.inputs, self.outputs = self._cleanup_graph(inputs, outputs, clone=False)
         self.inputs_type = tuple(input.type for input in self.inputs)
         self.outputs_type = tuple(output.type for output in self.outputs)
         self.nin = len(inputs)
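
For context on the `_cleanup_graph` change above: `FunctionGraph` already exposes a `clone` argument, so the method no longer needs to `copy()` the input/output lists itself; it either lets `FunctionGraph` clone once or, when `Composite.__init__` has already cloned, skips cloning entirely. A rough sketch of the behavioural difference, assuming a small scalar graph (names illustrative, asserts only to show variable identity):

    import pytensor.scalar as ps
    from pytensor.graph.fg import FunctionGraph

    x = ps.float64("x")
    out = ps.log(x + 1.0)

    # clone=True (the default) rebuilds the graph on fresh variables,
    # leaving `out` untouched by anything done to the FunctionGraph later
    fg_cloned = FunctionGraph([x], [out], clone=True)
    assert fg_cloned.outputs[0] is not out

    # clone=False takes the variables as-is; the caller must guarantee
    # they are not shared with another live graph
    fg_shared = FunctionGraph([x], [out], clone=False)
    assert fg_shared.outputs[0] is out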

pytensor/tensor/rewriting/elemwise.py

Lines changed: 7 additions & 6 deletions
@@ -915,12 +915,13 @@ def update_fuseable_mappings_after_fg_replace(
                 break
 
             scalar_inputs, scalar_outputs = self.elemwise_to_scalar(inputs, outputs)
-            composite_outputs = Elemwise(ps.Composite(scalar_inputs, scalar_outputs))(
-                *inputs
-            )
-            if not isinstance(composite_outputs, list):
-                composite_outputs = [composite_outputs]
-            for old_out, composite_out in zip(outputs, composite_outputs, strict=True):
+            composite_outputs = Elemwise(
+                # No need to clone Composite graph, because `self.elemwise_to_scalar` creates fresh variables
+                ps.Composite(scalar_inputs, scalar_outputs, clone_graph=False)
+            )(*inputs, return_list=True)
+            assert len(outputs) == len(composite_outputs)
+            for old_out, composite_out in zip(outputs, composite_outputs):
+                # Preserve any names on the original outputs
                 if old_out.name:
                     composite_out.name = old_out.name
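
On the FusionOptimizer side, besides passing `clone_graph=False`, the call site now uses `return_list=True` so that `Op.__call__` always returns a list of outputs and the old `isinstance(..., list)` fallback is unnecessary. A hedged sketch of what the rewritten call effectively does, with a hand-built Composite standing in for the one `elemwise_to_scalar` would produce (all names illustrative):

    import pytensor.scalar as ps
    import pytensor.tensor as pt
    from pytensor.tensor.elemwise import Elemwise

    # A fresh scalar graph, as elemwise_to_scalar would create
    xs = ps.float64("xs")
    ys = ps.float64("ys")
    comp = ps.Composite([xs, ys], [ps.exp(xs + ys)], clone_graph=False)

    x = pt.vector("x")
    y = pt.vector("y")

    # return_list=True returns a list even for a single-output Op, so the
    # outputs can be zipped against the originals without a type check
    composite_outputs = Elemwise(comp)(x, y, return_list=True)
    assert isinstance(composite_outputs, list) and len(composite_outputs) == 1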
