Skip to content

Commit 1cd19d4

Browse files
committed
.less composite copies
1 parent 078c9ba commit 1cd19d4

File tree

1 file changed

+45
-34
lines changed

1 file changed

+45
-34
lines changed

pytensor/scalar/basic.py

Lines changed: 45 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -4287,7 +4287,14 @@ class Composite(ScalarInnerGraphOp):
42874287

42884288
init_param: tuple[str, ...] = ("inputs", "outputs")
42894289

4290-
def __init__(self, inputs, outputs, name="Composite", cleanup_graph: bool = True):
4290+
def __init__(
4291+
self,
4292+
inputs,
4293+
outputs,
4294+
name="Composite",
4295+
clone_graph: bool = True,
4296+
output_types_preference=None,
4297+
):
42914298
self.name = name
42924299
self._name = None
42934300
# We need to clone the graph as sometimes its nodes already
@@ -4302,33 +4309,35 @@ def __init__(self, inputs, outputs, name="Composite", cleanup_graph: bool = True
43024309
for i in inputs:
43034310
assert i not in outputs # This isn't supported, use identity
43044311

4305-
if len(outputs) > 1 or not any(
4306-
isinstance(var.owner.op, Composite) for var in outputs
4307-
):
4308-
# No inner Composite
4309-
# FIXME: There could be a composite in the middle of the graph
4310-
inputs, outputs = clone(inputs, outputs)
4311-
else:
4312-
# Inner Composite that we need to flatten
4313-
assert len(outputs) == 1
4314-
# 1. Create a new graph from inputs up to the
4315-
# Composite
4316-
res = pytensor.compile.rebuild_collect_shared(
4317-
inputs=inputs, outputs=outputs[0].owner.inputs, copy_inputs_over=False
4318-
) # Clone also the inputs
4319-
# 2. We continue this partial clone with the graph in
4320-
# the inner Composite
4321-
res2 = pytensor.compile.rebuild_collect_shared(
4322-
inputs=outputs[0].owner.op.inputs,
4323-
outputs=outputs[0].owner.op.outputs,
4324-
replace=dict(zip(outputs[0].owner.op.inputs, res[1], strict=True)),
4325-
)
4326-
assert len(res2[1]) == len(outputs)
4327-
assert len(res[0]) == len(inputs)
4328-
assert res[0] != inputs
4329-
inputs, outputs = res[0], res2[1]
4312+
if clone_graph:
4313+
if len(outputs) > 1 or not any(
4314+
isinstance(var.owner.op, Composite) for var in outputs
4315+
):
4316+
# No inner Composite
4317+
# FIXME: There could be a composite in the middle of the graph
4318+
inputs, outputs = clone(inputs, outputs)
4319+
else:
4320+
# Inner Composite that we need to flatten
4321+
assert len(outputs) == 1
4322+
# 1. Create a new graph from inputs up to the
4323+
# Composite
4324+
res = pytensor.compile.rebuild_collect_shared(
4325+
inputs=inputs,
4326+
outputs=outputs[0].owner.inputs,
4327+
copy_inputs_over=False,
4328+
) # Clone also the inputs
4329+
# 2. We continue this partial clone with the graph in
4330+
# the inner Composite
4331+
res2 = pytensor.compile.rebuild_collect_shared(
4332+
inputs=outputs[0].owner.op.inputs,
4333+
outputs=outputs[0].owner.op.outputs,
4334+
replace=dict(zip(outputs[0].owner.op.inputs, res[1], strict=True)),
4335+
)
4336+
assert len(res2[1]) == len(outputs)
4337+
assert len(res[0]) == len(inputs)
4338+
assert res[0] != inputs
4339+
inputs, outputs = res[0], res2[1]
43304340

4331-
if cleanup_graph:
43324341
self.inputs, self.outputs = self._cleanup_graph(
43334342
inputs, outputs, clone=False
43344343
)
@@ -4338,7 +4347,7 @@ def __init__(self, inputs, outputs, name="Composite", cleanup_graph: bool = True
43384347
self.outputs_type = tuple(output.type for output in self.outputs)
43394348
self.nin = len(inputs)
43404349
self.nout = len(outputs)
4341-
super().__init__()
4350+
super().__init__(output_types_preference=output_types_preference)
43424351

43434352
def __str__(self):
43444353
if self._name is not None:
@@ -4374,15 +4383,17 @@ def make_new_inplace(self, output_types_preference=None, name=None):
43744383
This fct allow fix patch this.
43754384
43764385
"""
4386+
43774387
d = {k: getattr(self, k) for k in self.init_param}
4378-
out = type(self)(**d, cleanup_graph=False)
4379-
if name:
4380-
out.name = name
4381-
else:
4382-
name = out.name
4388+
out = type(self)(
4389+
**d,
4390+
clone_graph=False,
4391+
output_types_preference=output_types_preference,
4392+
name=name or self.name,
4393+
)
4394+
# No need to recompute the _cocde and nodenames if they were already computed (which is true if the hash of the Op was requested)
43834395
out._c_code = self._c_code
43844396
out.nodenames = self.nodenames
4385-
super(Composite, out).__init__(output_types_preference, name)
43864397
return out
43874398

43884399
@property

0 commit comments

Comments
 (0)