Skip to content

Commit e5d5dec

Browse files
committed
Speedup FusionOptimizer.elemwise_to_scalar
1 parent 0847ba7 commit e5d5dec

File tree

2 files changed

+23
-40
lines changed

2 files changed

+23
-40
lines changed

pytensor/scalar/basic.py

Lines changed: 5 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -778,9 +778,11 @@ def get_scalar_type(dtype, cache: dict[str, ScalarType] = {}) -> ScalarType:
778 778
This caches objects to save allocation and run time.
779 779
780 780
"""
781-
if dtype not in cache:
782-
cache[dtype] = ScalarType(dtype=dtype)
783-
return cache[dtype]
781+
try:
782+
return cache[dtype]
783+
except KeyError:
784+
cache[dtype] = res = ScalarType(dtype=dtype)
785+
return res
784 786

785 787

786 788
# Register C code for ViewOp on Scalars.

pytensor/tensor/rewriting/elemwise.py

Lines changed: 18 additions & 37 deletions
Original file line number | Diff line number | Diff line change
@@ -14,7 +14,7 @@
14 14
from pytensor.compile.mode import get_target_language
15 15
from pytensor.configdefaults import config
16 16
from pytensor.graph import FunctionGraph, Op
17-
from pytensor.graph.basic import Apply, Variable, ancestors
17+
from pytensor.graph.basic import Apply, Variable, ancestors, io_toposort
18 18
from pytensor.graph.destroyhandler import DestroyHandler, inplace_candidates
19 19
from pytensor.graph.features import ReplaceValidate
20 20
from pytensor.graph.fg import Output
@@ -528,43 +528,24 @@ def add_requirements(self, fgraph):
528 528

529 529
@staticmethod
530 530
def elemwise_to_scalar(inputs, outputs):
531-
replace_inputs = [(inp, inp.clone()) for inp in inputs]
532-
outputs = clone_replace(outputs, replace=replace_inputs)
533-
534-
inputs = [inp for _, inp in replace_inputs]
535-
fg = FunctionGraph(inputs=inputs, outputs=outputs, clone=False)
536-
middle_inputs = []
537-
538-
scalar_inputs = [
539-
ps.get_scalar_type(inp.type.dtype).make_variable() for inp in inputs
540-
]
541-
middle_scalar_inputs = []
542-
543-
for node in fg.toposort():
544-
node_scalar_inputs = []
545-
for inp in node.inputs:
546-
if inp in inputs:
547-
node_scalar_inputs.append(scalar_inputs[inputs.index(inp)])
548-
elif inp in middle_inputs:
549-
node_scalar_inputs.append(
550-
middle_scalar_inputs[middle_inputs.index(inp)]
531+
replacement = {
532+
inp: ps.get_scalar_type(inp.type.dtype).make_variable() for inp in inputs
533+
}
534+
for node in io_toposort(inputs, outputs):
535+
scalar_inputs = [replacement[inp] for inp in node.inputs]
536+
replacement.update(
537+
dict(
538+
zip(
539+
node.outputs,
540+
node.op.scalar_op.make_node(*scalar_inputs).outputs,
551 541
)
552-
else:
553-
new_scalar_input = ps.get_scalar_type(
554-
inp.type.dtype
555-
).make_variable()
556-
node_scalar_inputs.append(new_scalar_input)
557-
middle_scalar_inputs.append(new_scalar_input)
558-
middle_inputs.append(inp)
559-
560-
new_scalar_node = node.op.scalar_op.make_node(*node_scalar_inputs)
561-
middle_scalar_inputs.append(new_scalar_node.outputs[0])
562-
middle_inputs.append(node.outputs[0])
563-
564-
scalar_outputs = [
565-
middle_scalar_inputs[middle_inputs.index(out)] for out in fg.outputs
566-
]
567-
return scalar_inputs, scalar_outputs
542+
)
543+
)
544+
545+
return (
546+
[replacement[inp] for inp in inputs],
547+
[replacement[out] for out in outputs],
548+
)
568 549

569 550
def apply(self, fgraph):
570 551
if fgraph.profile:

0 commit comments

Comments
 (0)