Skip to content

Commit 69237f3

Browse files
committed
Speed up FusionOptimizer.elemwise_to_scalar
1 parent 820d99d commit 69237f3

File tree

2 files changed

+23
-40
lines changed

2 files changed

+23
-40
lines changed

pytensor/scalar/basic.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -779,9 +779,11 @@ def get_scalar_type(dtype, cache: dict[str, ScalarType] = {}) -> ScalarType:
779779
This caches objects to save allocation and run time.
780780
781781
"""
782-
if dtype not in cache:
783-
cache[dtype] = ScalarType(dtype=dtype)
784-
return cache[dtype]
782+
try:
783+
return cache[dtype]
784+
except KeyError:
785+
cache[dtype] = res = ScalarType(dtype=dtype)
786+
return res
785787

786788

787789
# Register C code for ViewOp on Scalars.

pytensor/tensor/rewriting/elemwise.py

Lines changed: 18 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@
2828
)
2929
from pytensor.graph.rewriting.db import SequenceDB
3030
from pytensor.graph.rewriting.unify import OpPattern
31-
from pytensor.graph.traversal import ancestors
31+
from pytensor.graph.traversal import ancestors, toposort
3232
from pytensor.graph.utils import InconsistencyError, MethodNotDefined
3333
from pytensor.scalar.math import Grad2F1Loop, _grad_2f1_loop
3434
from pytensor.tensor.basic import (
@@ -530,43 +530,24 @@ def add_requirements(self, fgraph):
530530

531531
@staticmethod
532532
def elemwise_to_scalar(inputs, outputs):
533-
replace_inputs = [(inp, inp.clone()) for inp in inputs]
534-
outputs = clone_replace(outputs, replace=replace_inputs)
535-
536-
inputs = [inp for _, inp in replace_inputs]
537-
fg = FunctionGraph(inputs=inputs, outputs=outputs, clone=False)
538-
middle_inputs = []
539-
540-
scalar_inputs = [
541-
ps.get_scalar_type(inp.type.dtype).make_variable() for inp in inputs
542-
]
543-
middle_scalar_inputs = []
544-
545-
for node in fg.toposort():
546-
node_scalar_inputs = []
547-
for inp in node.inputs:
548-
if inp in inputs:
549-
node_scalar_inputs.append(scalar_inputs[inputs.index(inp)])
550-
elif inp in middle_inputs:
551-
node_scalar_inputs.append(
552-
middle_scalar_inputs[middle_inputs.index(inp)]
533+
replacement = {
534+
inp: ps.get_scalar_type(inp.type.dtype).make_variable() for inp in inputs
535+
}
536+
for node in toposort(outputs, blockers=inputs):
537+
scalar_inputs = [replacement[inp] for inp in node.inputs]
538+
replacement.update(
539+
dict(
540+
zip(
541+
node.outputs,
542+
node.op.scalar_op.make_node(*scalar_inputs).outputs,
553543
)
554-
else:
555-
new_scalar_input = ps.get_scalar_type(
556-
inp.type.dtype
557-
).make_variable()
558-
node_scalar_inputs.append(new_scalar_input)
559-
middle_scalar_inputs.append(new_scalar_input)
560-
middle_inputs.append(inp)
561-
562-
new_scalar_node = node.op.scalar_op.make_node(*node_scalar_inputs)
563-
middle_scalar_inputs.append(new_scalar_node.outputs[0])
564-
middle_inputs.append(node.outputs[0])
565-
566-
scalar_outputs = [
567-
middle_scalar_inputs[middle_inputs.index(out)] for out in fg.outputs
568-
]
569-
return scalar_inputs, scalar_outputs
544+
)
545+
)
546+
547+
return (
548+
[replacement[inp] for inp in inputs],
549+
[replacement[out] for out in outputs],
550+
)
570551

571552
def apply(self, fgraph):
572553
if fgraph.profile:

0 commit comments

Comments
 (0)