Skip to content

Commit f9b4c0a

Browse files
committed
Avoid repeated work with Composite Ops
1 parent f3c5f0e commit f9b4c0a

File tree

1 file changed

+10
-3
lines changed

1 file changed

+10
-3
lines changed

pytensor/scalar/basic.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4280,7 +4280,7 @@ class Composite(ScalarInnerGraphOp):
42804280

42814281
init_param: tuple[str, ...] = ("inputs", "outputs")
42824282

4283-
def __init__(self, inputs, outputs, name="Composite"):
4283+
def __init__(self, inputs, outputs, name="Composite", cleanup_graph: bool = True):
42844284
self.name = name
42854285
self._name = None
42864286
# We need to clone the graph as sometimes its nodes already
@@ -4299,6 +4299,7 @@ def __init__(self, inputs, outputs, name="Composite"):
42994299
isinstance(var.owner.op, Composite) for var in outputs
43004300
):
43014301
# No inner Composite
4302+
# FIXME: There could be a composite in the middle of the graph
43024303
inputs, outputs = clone(inputs, outputs)
43034304
else:
43044305
# Inner Composite that we need to flatten
@@ -4320,7 +4321,12 @@ def __init__(self, inputs, outputs, name="Composite"):
43204321
assert res[0] != inputs
43214322
inputs, outputs = res[0], res2[1]
43224323

4323-
self.inputs, self.outputs = self._cleanup_graph(inputs, outputs, clone=False)
4324+
if cleanup_graph:
4325+
self.inputs, self.outputs = self._cleanup_graph(
4326+
inputs, outputs, clone=False
4327+
)
4328+
else:
4329+
self.inputs, self.outputs = inputs, outputs
43244330
self.inputs_type = tuple(input.type for input in self.inputs)
43254331
self.outputs_type = tuple(output.type for output in self.outputs)
43264332
self.nin = len(inputs)
@@ -4362,11 +4368,12 @@ def make_new_inplace(self, output_types_preference=None, name=None):
43624368
43634369
"""
43644370
d = {k: getattr(self, k) for k in self.init_param}
4365-
out = self.__class__(**d)
4371+
out = type(self)(**d, cleanup_graph=False)
43664372
if name:
43674373
out.name = name
43684374
else:
43694375
name = out.name
4376+
out._c_code = self._c_code
43704377
super(Composite, out).__init__(output_types_preference, name)
43714378
return out
43724379

0 commit comments

Comments
 (0)