|
27 | 27 | out2in, |
28 | 28 | ) |
29 | 29 | from pytensor.graph.rewriting.db import SequenceDB |
30 | | -from pytensor.graph.traversal import ancestors, toposort |
| 30 | +from pytensor.graph.traversal import ancestors, graph_inputs, toposort |
31 | 31 | from pytensor.graph.utils import InconsistencyError, MethodNotDefined |
| 32 | +from pytensor.scalar import ScalarConstant |
32 | 33 | from pytensor.scalar.math import Grad2F1Loop, _grad_2f1_loop |
33 | 34 | from pytensor.tensor.basic import ( |
34 | 35 | MakeVector, |
@@ -1015,31 +1016,48 @@ def print_profile(stream, prof, level=0): |
1015 | 1016 | @node_rewriter([Elemwise]) |
1016 | 1017 | def local_useless_composite_outputs(fgraph, node): |
1017 | 1018 | """Remove inputs and outputs of Composite Ops that are not used anywhere.""" |
1018 | | - if not ( |
1019 | | - isinstance(node.op, Elemwise) and isinstance(node.op.scalar_op, ps.Composite) |
1020 | | - ): |
1021 | | - return |
1022 | 1019 | comp = node.op.scalar_op |
1023 | | - used_outputs_idxs = [ |
1024 | | - i for i, o_extern in enumerate(node.outputs) if fgraph.clients[o_extern] |
1025 | | - ] |
1026 | | - used_inner_outputs = [comp.outputs[i] for i in used_outputs_idxs] |
1027 | | - comp_fgraph = FunctionGraph( |
1028 | | - inputs=comp.inputs, outputs=used_inner_outputs, clone=False |
1029 | | - ) |
| 1020 | + |
| 1021 | + if not isinstance(node.op.scalar_op, ps.Composite): |
| 1022 | + return None |
| 1023 | + |
| 1024 | + clients = fgraph.clients |
| 1025 | + outer_inputs, outer_outputs = node.inputs, node.outputs |
| 1026 | + inner_inputs, inner_outputs = comp.inputs, comp.outputs |
| 1027 | + |
| 1028 | + used_inner_outputs = { |
| 1029 | + inner_out |
| 1030 | + for inner_out, outer_out in zip(inner_outputs, outer_outputs) |
| 1031 | + if clients[outer_out] |
| 1032 | + } |
| 1033 | + used_inner_inputs = { |
| 1034 | + inner_inp |
| 1035 | + for inner_inp in graph_inputs(used_inner_outputs) |
| 1036 | + if not isinstance(inner_inp, ScalarConstant) |
| 1037 | + } |
| 1038 | + |
| 1039 | + if len(used_inner_inputs) == len(outer_inputs) or len(used_inner_outputs) == len( |
| 1040 | + outer_outputs |
| 1041 | + ): |
| 1042 | + return None |
| 1043 | + |
1030 | 1044 | used_inputs_idxs = [ |
1031 | | - i |
1032 | | - for i, i_intern in enumerate(comp_fgraph.inputs) |
1033 | | - if comp_fgraph.clients[i_intern] |
| 1045 | + i for i, inp in enumerate(inner_inputs) if inp in used_inner_inputs |
1034 | 1046 | ] |
1035 | | - used_inner_inputs = [comp.inputs[i] for i in used_inputs_idxs] |
1036 | | - if len(used_inner_inputs) < len(node.inputs) or len(used_inner_outputs) < len( |
1037 | | - node.outputs |
1038 | | - ): |
1039 | | - used_inputs = [node.inputs[i] for i in used_inputs_idxs] |
1040 | | - c = ps.Composite(inputs=used_inner_inputs, outputs=used_inner_outputs) |
1041 | | - e = Elemwise(scalar_op=c)(*used_inputs, return_list=True) |
1042 | | - return dict(zip([node.outputs[i] for i in used_outputs_idxs], e, strict=True)) |
| 1047 | + used_inner_inputs = [inner_inputs[i] for i in used_inputs_idxs] |
| 1048 | + used_outer_inputs = [outer_inputs[i] for i in used_inputs_idxs] |
| 1049 | + |
| 1050 | + new_comp = ps.Composite(inputs=used_inner_inputs, outputs=used_inner_outputs) |
| 1051 | + new_outer_outputs = Elemwise(scalar_op=new_comp)( |
| 1052 | + *used_outer_inputs, return_list=True |
| 1053 | + ) |
| 1054 | + |
| 1055 | + used_outer_outputs = ( |
| 1056 | + outer_outputs[i] |
| 1057 | + for i, out in enumerate(inner_outputs) |
| 1058 | + if out in used_inner_outputs |
| 1059 | + ) |
| 1060 | + return dict(zip(used_outer_outputs, new_outer_outputs, strict=True)) |
1043 | 1061 |
|
1044 | 1062 |
|
1045 | 1063 | @node_rewriter([CAReduce]) |
|
0 commit comments