Skip to content

Commit 3e9901a

Browse files
committed
Avoid creating FunctionGraphs for Composite rewrite
1 parent fecb508 commit 3e9901a

File tree

2 files changed

+59
-39
lines changed

2 files changed

+59
-39
lines changed

pytensor/scalar/basic.py

Lines changed: 18 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -4331,24 +4331,25 @@ def __str__(self):
43314331
if self._name is not None:
43324332
return self._name
43334333

4334-
# Rename internal variables
4335-
for i, r in enumerate(self.fgraph.inputs):
4336-
r.name = f"i{i}"
4337-
for i, r in enumerate(self.fgraph.outputs):
4338-
r.name = f"o{i}"
4339-
io = set(self.fgraph.inputs + self.fgraph.outputs)
4340-
for i, r in enumerate(self.fgraph.variables):
4341-
if (
4342-
not isinstance(r, Constant)
4343-
and r not in io
4344-
and len(self.fgraph.clients[r]) > 1
4345-
):
4346-
r.name = f"t{i}"
4334+
fgraph = self.fgraph
43474335

4348-
if len(self.fgraph.outputs) > 1 or len(self.fgraph.apply_nodes) > 10:
4336+
if len(fgraph.outputs) > 1 or len(fgraph.apply_nodes) > 10:
43494337
self._name = "Composite{...}"
43504338
else:
4351-
outputs_str = ", ".join(pprint(output) for output in self.fgraph.outputs)
4339+
# Rename internal variables
4340+
for i, r in enumerate(fgraph.inputs):
4341+
r.name = f"i{i}"
4342+
for i, r in enumerate(fgraph.outputs):
4343+
r.name = f"o{i}"
4344+
io = set(fgraph.inputs + fgraph.outputs)
4345+
for i, r in enumerate(fgraph.variables):
4346+
if (
4347+
not isinstance(r, Constant)
4348+
and r not in io
4349+
and len(fgraph.clients[r]) > 1
4350+
):
4351+
r.name = f"t{i}"
4352+
outputs_str = ", ".join(pprint(output) for output in fgraph.outputs)
43524353
self._name = f"Composite{{{outputs_str}}}"
43534354

43544355
return self._name
@@ -4433,9 +4434,10 @@ def c_code_template(self):
44334434
fg = self.fgraph
44344435
subd = {e: f"%(i{i})s" for i, e in enumerate(fg.inputs)}
44354436

4437+
inputs_set = frozenset(fg.inputs)
44364438
for var in fg.variables:
44374439
if var.owner is None:
4438-
if var not in fg.inputs:
4440+
if var not in inputs_set:
44394441
# This is an orphan
44404442
if isinstance(var, Constant) and isinstance(var.type, CLinkerType):
44414443
subd[var] = f"({var.type.c_literal(var.data)})"

pytensor/tensor/rewriting/elemwise.py

Lines changed: 41 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -27,8 +27,9 @@
2727
out2in,
2828
)
2929
from pytensor.graph.rewriting.db import SequenceDB
30-
from pytensor.graph.traversal import ancestors, toposort
30+
from pytensor.graph.traversal import ancestors, graph_inputs, toposort
3131
from pytensor.graph.utils import InconsistencyError, MethodNotDefined
32+
from pytensor.scalar import ScalarConstant
3233
from pytensor.scalar.math import Grad2F1Loop, _grad_2f1_loop
3334
from pytensor.tensor.basic import (
3435
MakeVector,
@@ -1015,31 +1016,48 @@ def print_profile(stream, prof, level=0):
10151016
@node_rewriter([Elemwise])
10161017
def local_useless_composite_outputs(fgraph, node):
10171018
"""Remove inputs and outputs of Composite Ops that are not used anywhere."""
1018-
if not (
1019-
isinstance(node.op, Elemwise) and isinstance(node.op.scalar_op, ps.Composite)
1020-
):
1021-
return
10221019
comp = node.op.scalar_op
1023-
used_outputs_idxs = [
1024-
i for i, o_extern in enumerate(node.outputs) if fgraph.clients[o_extern]
1025-
]
1026-
used_inner_outputs = [comp.outputs[i] for i in used_outputs_idxs]
1027-
comp_fgraph = FunctionGraph(
1028-
inputs=comp.inputs, outputs=used_inner_outputs, clone=False
1029-
)
1020+
1021+
if not isinstance(node.op.scalar_op, ps.Composite):
1022+
return None
1023+
1024+
clients = fgraph.clients
1025+
outer_inputs, outer_outputs = node.inputs, node.outputs
1026+
inner_inputs, inner_outputs = comp.inputs, comp.outputs
1027+
1028+
used_inner_outputs = {
1029+
inner_out
1030+
for inner_out, outer_out in zip(inner_outputs, outer_outputs)
1031+
if clients[outer_out]
1032+
}
1033+
used_inner_inputs = {
1034+
inner_inp
1035+
for inner_inp in graph_inputs(used_inner_outputs)
1036+
if not isinstance(inner_inp, ScalarConstant)
1037+
}
1038+
1039+
if len(used_inner_inputs) == len(outer_inputs) or len(used_inner_outputs) == len(
1040+
outer_outputs
1041+
):
1042+
return None
1043+
10301044
used_inputs_idxs = [
1031-
i
1032-
for i, i_intern in enumerate(comp_fgraph.inputs)
1033-
if comp_fgraph.clients[i_intern]
1045+
i for i, inp in enumerate(inner_inputs) if inp in used_inner_inputs
10341046
]
1035-
used_inner_inputs = [comp.inputs[i] for i in used_inputs_idxs]
1036-
if len(used_inner_inputs) < len(node.inputs) or len(used_inner_outputs) < len(
1037-
node.outputs
1038-
):
1039-
used_inputs = [node.inputs[i] for i in used_inputs_idxs]
1040-
c = ps.Composite(inputs=used_inner_inputs, outputs=used_inner_outputs)
1041-
e = Elemwise(scalar_op=c)(*used_inputs, return_list=True)
1042-
return dict(zip([node.outputs[i] for i in used_outputs_idxs], e, strict=True))
1047+
used_inner_inputs = [inner_inputs[i] for i in used_inputs_idxs]
1048+
used_outer_inputs = [outer_inputs[i] for i in used_inputs_idxs]
1049+
1050+
new_comp = ps.Composite(inputs=used_inner_inputs, outputs=used_inner_outputs)
1051+
new_outer_outputs = Elemwise(scalar_op=new_comp)(
1052+
*used_outer_inputs, return_list=True
1053+
)
1054+
1055+
used_outer_outputs = (
1056+
outer_outputs[i]
1057+
for i, out in enumerate(inner_outputs)
1058+
if out in used_inner_outputs
1059+
)
1060+
return dict(zip(used_outer_outputs, new_outer_outputs, strict=True))
10431061

10441062

10451063
@node_rewriter([CAReduce])

0 commit comments

Comments
 (0)