|
28 | 28 | ) |
29 | 29 | from pytensor.graph.rewriting.db import SequenceDB |
30 | 30 | from pytensor.graph.rewriting.unify import OpPattern |
31 | | -from pytensor.graph.traversal import ancestors |
| 31 | +from pytensor.graph.traversal import ancestors, toposort |
32 | 32 | from pytensor.graph.utils import InconsistencyError, MethodNotDefined |
33 | 33 | from pytensor.scalar.math import Grad2F1Loop, _grad_2f1_loop |
34 | 34 | from pytensor.tensor.basic import ( |
@@ -530,43 +530,24 @@ def add_requirements(self, fgraph): |
530 | 530 |
|
531 | 531 | @staticmethod |
532 | 532 | def elemwise_to_scalar(inputs, outputs): |
533 | | - replace_inputs = [(inp, inp.clone()) for inp in inputs] |
534 | | - outputs = clone_replace(outputs, replace=replace_inputs) |
535 | | - |
536 | | - inputs = [inp for _, inp in replace_inputs] |
537 | | - fg = FunctionGraph(inputs=inputs, outputs=outputs, clone=False) |
538 | | - middle_inputs = [] |
539 | | - |
540 | | - scalar_inputs = [ |
541 | | - ps.get_scalar_type(inp.type.dtype).make_variable() for inp in inputs |
542 | | - ] |
543 | | - middle_scalar_inputs = [] |
544 | | - |
545 | | - for node in fg.toposort(): |
546 | | - node_scalar_inputs = [] |
547 | | - for inp in node.inputs: |
548 | | - if inp in inputs: |
549 | | - node_scalar_inputs.append(scalar_inputs[inputs.index(inp)]) |
550 | | - elif inp in middle_inputs: |
551 | | - node_scalar_inputs.append( |
552 | | - middle_scalar_inputs[middle_inputs.index(inp)] |
| 533 | + replacement = { |
| 534 | + inp: ps.get_scalar_type(inp.type.dtype).make_variable() for inp in inputs |
| 535 | + } |
| 536 | + for node in toposort(outputs, blockers=inputs): |
| 537 | + scalar_inputs = [replacement[inp] for inp in node.inputs] |
| 538 | + replacement.update( |
| 539 | + dict( |
| 540 | + zip( |
| 541 | + node.outputs, |
| 542 | + node.op.scalar_op.make_node(*scalar_inputs).outputs, |
553 | 543 | ) |
554 | | - else: |
555 | | - new_scalar_input = ps.get_scalar_type( |
556 | | - inp.type.dtype |
557 | | - ).make_variable() |
558 | | - node_scalar_inputs.append(new_scalar_input) |
559 | | - middle_scalar_inputs.append(new_scalar_input) |
560 | | - middle_inputs.append(inp) |
561 | | - |
562 | | - new_scalar_node = node.op.scalar_op.make_node(*node_scalar_inputs) |
563 | | - middle_scalar_inputs.append(new_scalar_node.outputs[0]) |
564 | | - middle_inputs.append(node.outputs[0]) |
565 | | - |
566 | | - scalar_outputs = [ |
567 | | - middle_scalar_inputs[middle_inputs.index(out)] for out in fg.outputs |
568 | | - ] |
569 | | - return scalar_inputs, scalar_outputs |
| 544 | + ) |
| 545 | + ) |
| 546 | + |
| 547 | + return ( |
| 548 | + [replacement[inp] for inp in inputs], |
| 549 | + [replacement[out] for out in outputs], |
| 550 | + ) |
570 | 551 |
|
571 | 552 | def apply(self, fgraph): |
572 | 553 | if fgraph.profile: |
|
0 commit comments