Commit 62de419

Avoid double cloning of Composite Ops created by FusionOptimizer

1 parent 0e5c760

2 files changed: +19 -13 lines changed

pytensor/scalar/basic.py

Lines changed: 12 additions & 7 deletions
@@ -13,7 +13,6 @@
 import builtins
 import math
 from collections.abc import Callable
-from copy import copy
 from itertools import chain
 from textwrap import dedent
 from typing import Any, TypeAlias
@@ -4094,12 +4093,12 @@ def __init__(self, *args, **kwargs):
         self.prepare_node_called = set()
         super().__init__(*args, **kwargs)

-    def _cleanup_graph(self, inputs, outputs):
+    def _cleanup_graph(self, inputs, outputs, clone: builtins.bool = True):
         # TODO: We could convert to TensorVariable, optimize graph,
         # and then convert back to ScalarVariable.
         # This would introduce rewrites like `log(1 + x) -> log1p`.

-        fgraph = FunctionGraph(copy(inputs), copy(outputs))
+        fgraph = FunctionGraph(inputs, outputs, clone=clone)

         # Validate node types
         for node in fgraph.apply_nodes:
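
`FunctionGraph` can already clone the graph it is handed (its `clone` argument defaults to True), so the explicit `copy()` calls were redundant; `_cleanup_graph` now simply forwards the caller's choice. A minimal sketch of what that flag changes, using fresh scalar variables (illustrative only, not part of the commit):

    import pytensor.scalar as ps
    from pytensor.graph.fg import FunctionGraph

    x = ps.float64("x")
    y = ps.exp(x)

    cloned = FunctionGraph([x], [y])               # clone=True by default: works on copies
    shared = FunctionGraph([x], [y], clone=False)  # works on the very same variables

    assert cloned.inputs[0] is not x
    assert shared.inputs[0] is x
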
@@ -4282,7 +4281,9 @@ class Composite(ScalarInnerGraphOp):

     init_param: tuple[str, ...] = ("inputs", "outputs")

-    def __init__(self, inputs, outputs, name="Composite"):
+    def __init__(
+        self, inputs, outputs, name="Composite", clone_graph: builtins.bool = True
+    ):
         self.name = name
         self._name = None
         # We need to clone the graph as sometimes its nodes already
@@ -4300,10 +4301,13 @@ def __init__(self, inputs, outputs, name="Composite"):
         if len(outputs) > 1 or not any(
             isinstance(var.owner.op, Composite) for var in outputs
         ):
-            # No inner Composite
-            inputs, outputs = clone(inputs, outputs)
+            if clone_graph:
+                inputs, outputs = clone(inputs, outputs)
+
         else:
             # Inner Composite that we need to flatten
+            # FIXME: There could be a composite in the middle of the graph, why is this here?
+            # If anything it should be an optimization, but I suspect lower-level compilation can handle this anyway.
             assert len(outputs) == 1
             # 1. Create a new graph from inputs up to the
             # Composite
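
The `else` branch is only reached when the single output is itself produced by a `Composite`, in which case the inner graph is flattened into the new op rather than nested (hence the new FIXME questioning its placement). A hypothetical example of a graph that would take this branch:

    import pytensor.scalar as ps

    x = ps.float64("x")

    inner = ps.Composite([x], [ps.exp(x)])
    y = inner(x)  # y.owner.op is itself a Composite

    # Single output coming straight out of a Composite: __init__ flattens
    # the inner graph instead of taking the (now optional) clone above.
    outer = ps.Composite([x], [y])
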
@@ -4322,7 +4326,8 @@ def __init__(self, inputs, outputs, name="Composite"):
             assert res[0] != inputs
             inputs, outputs = res[0], res2[1]

-        self.inputs, self.outputs = self._cleanup_graph(inputs, outputs)
+        # We already cloned the graph, or the user told us there was no need for it
+        self.inputs, self.outputs = self._cleanup_graph(inputs, outputs, clone=False)
         self.inputs_type = tuple(input.type for input in self.inputs)
         self.outputs_type = tuple(output.type for output in self.outputs)
         self.nin = len(inputs)
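
For other callers the new flag is opt-in: the default `clone_graph=True` keeps the old defensive clone, while code that has just built the scalar graph itself (as `FusionOptimizer.elemwise_to_scalar` does in the next file) can skip it. A minimal usage sketch, assuming the variables are used nowhere else:

    import pytensor.scalar as ps

    x = ps.float64("x")
    y = ps.float64("y")
    out = ps.mul(ps.add(x, y), ps.exp(x))

    # Freshly created, single-use variables: cloning them again in
    # Composite.__init__ would only add overhead.
    comp = ps.Composite([x, y], [out], clone_graph=False)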

pytensor/tensor/rewriting/elemwise.py

Lines changed: 7 additions & 6 deletions
@@ -915,12 +915,13 @@ def update_fuseable_mappings_after_fg_replace(
                break

            scalar_inputs, scalar_outputs = self.elemwise_to_scalar(inputs, outputs)
-            composite_outputs = Elemwise(ps.Composite(scalar_inputs, scalar_outputs))(
-                *inputs
-            )
-            if not isinstance(composite_outputs, list):
-                composite_outputs = [composite_outputs]
-            for old_out, composite_out in zip(outputs, composite_outputs, strict=True):
+            composite_outputs = Elemwise(
+                # No need to clone Composite graph, because `self.elemwise_to_scalar` creates fresh variables
+                ps.Composite(scalar_inputs, scalar_outputs, clone_graph=False)
+            )(*inputs, return_list=True)
+            assert len(outputs) == len(composite_outputs)
+            for old_out, composite_out in zip(outputs, composite_outputs):
+                # Preserve any names on the original outputs
                if old_out.name:
                    composite_out.name = old_out.name
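
`return_list=True` is the standard `Op.__call__` flag that always returns a list of outputs, which is why the old `isinstance(..., list)` normalization could be dropped. A small illustration with an ordinary `Elemwise` op (not part of this commit):

    import pytensor.scalar as ps
    import pytensor.tensor as pt
    from pytensor.tensor.elemwise import Elemwise

    x = pt.vector("x")
    exp_op = Elemwise(ps.exp)

    out = exp_op(x)                      # single output -> a bare Variable
    outs = exp_op(x, return_list=True)   # always a list, even with one output

    assert not isinstance(out, list)
    assert isinstance(outs, list) and len(outs) == 1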
