|
14 | 14 | from pytensor.compile.mode import get_target_language |
15 | 15 | from pytensor.configdefaults import config |
16 | 16 | from pytensor.graph import FunctionGraph, Op |
17 | | -from pytensor.graph.basic import Apply, Variable, ancestors |
| 17 | +from pytensor.graph.basic import Apply, Variable, ancestors, io_toposort |
18 | 18 | from pytensor.graph.destroyhandler import DestroyHandler, inplace_candidates |
19 | 19 | from pytensor.graph.features import ReplaceValidate |
20 | 20 | from pytensor.graph.fg import Output |
@@ -528,43 +528,24 @@ def add_requirements(self, fgraph): |
528 | 528 |
|
529 | 529 | @staticmethod |
530 | 530 | def elemwise_to_scalar(inputs, outputs): |
531 | | - replace_inputs = [(inp, inp.clone()) for inp in inputs] |
532 | | - outputs = clone_replace(outputs, replace=replace_inputs) |
533 | | - |
534 | | - inputs = [inp for _, inp in replace_inputs] |
535 | | - fg = FunctionGraph(inputs=inputs, outputs=outputs, clone=False) |
536 | | - middle_inputs = [] |
537 | | - |
538 | | - scalar_inputs = [ |
539 | | - ps.get_scalar_type(inp.type.dtype).make_variable() for inp in inputs |
540 | | - ] |
541 | | - middle_scalar_inputs = [] |
542 | | - |
543 | | - for node in fg.toposort(): |
544 | | - node_scalar_inputs = [] |
545 | | - for inp in node.inputs: |
546 | | - if inp in inputs: |
547 | | - node_scalar_inputs.append(scalar_inputs[inputs.index(inp)]) |
548 | | - elif inp in middle_inputs: |
549 | | - node_scalar_inputs.append( |
550 | | - middle_scalar_inputs[middle_inputs.index(inp)] |
| 531 | + replacement = { |
| 532 | + inp: ps.get_scalar_type(inp.type.dtype).make_variable() for inp in inputs |
| 533 | + } |
| 534 | + for node in io_toposort(inputs, outputs): |
| 535 | + scalar_inputs = [replacement[inp] for inp in node.inputs] |
| 536 | + replacement.update( |
| 537 | + dict( |
| 538 | + zip( |
| 539 | + node.outputs, |
| 540 | + node.op.scalar_op.make_node(*scalar_inputs).outputs, |
551 | 541 | ) |
552 | | - else: |
553 | | - new_scalar_input = ps.get_scalar_type( |
554 | | - inp.type.dtype |
555 | | - ).make_variable() |
556 | | - node_scalar_inputs.append(new_scalar_input) |
557 | | - middle_scalar_inputs.append(new_scalar_input) |
558 | | - middle_inputs.append(inp) |
559 | | - |
560 | | - new_scalar_node = node.op.scalar_op.make_node(*node_scalar_inputs) |
561 | | - middle_scalar_inputs.append(new_scalar_node.outputs[0]) |
562 | | - middle_inputs.append(node.outputs[0]) |
563 | | - |
564 | | - scalar_outputs = [ |
565 | | - middle_scalar_inputs[middle_inputs.index(out)] for out in fg.outputs |
566 | | - ] |
567 | | - return scalar_inputs, scalar_outputs |
| 542 | + ) |
| 543 | + ) |
| 544 | + |
| 545 | + return ( |
| 546 | + [replacement[inp] for inp in inputs], |
| 547 | + [replacement[out] for out in outputs], |
| 548 | + ) |
568 | 549 |
|
569 | 550 | def apply(self, fgraph): |
570 | 551 | if fgraph.profile: |
|
0 commit comments