Skip to content

Commit f3c5f0e

Browse files
committed
Speedup _gemm_canonicalize
1 parent 3e9901a commit f3c5f0e

File tree

1 file changed

+36
-25
lines changed

1 file changed

+36
-25
lines changed

pytensor/tensor/rewriting/blas.py

Lines changed: 36 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,7 @@
6060
import numpy as np
6161

6262
from pytensor.graph.traversal import toposort
63+
from pytensor.scalar import Add, Mul, Neg, Sub
6364
from pytensor.tensor.rewriting.basic import register_specialize
6465

6566

@@ -100,10 +101,7 @@
100101
from pytensor.tensor.math import (
101102
Dot,
102103
_matmul,
103-
add,
104104
mul,
105-
neg,
106-
sub,
107105
variadic_add,
108106
)
109107
from pytensor.tensor.rewriting.elemwise import local_dimshuffle_lift
@@ -237,22 +235,27 @@ def scaled(thing):
237235
rval.append(scaled(r))
238236
return rval
239237

240-
if maxclients and len(fgraph.clients[r]) > maxclients:
238+
if (
239+
(r.owner is None)
240+
or (not isinstance(r.owner.op, Elemwise))
241+
or (maxclients and len(fgraph.clients[r]) > maxclients)
242+
):
241243
rval.append((scale, r))
242244
return rval
243245

244-
if r.owner and r.owner.op == sub:
246+
scalar_op = r.owner.op.scalar_op
247+
if isinstance(scalar_op, Sub):
245248
_gemm_canonicalize(fgraph, r.owner.inputs[0], scale, rval, 1)
246249
_gemm_canonicalize(fgraph, r.owner.inputs[1], -scale, rval, 1)
247250

248-
elif r.owner and r.owner.op == add:
251+
elif isinstance(scalar_op, Add):
249252
for i in r.owner.inputs:
250253
_gemm_canonicalize(fgraph, i, scale, rval, 1)
251254

252-
elif r.owner and r.owner.op == neg:
255+
elif isinstance(scalar_op, Neg):
253256
_gemm_canonicalize(fgraph, r.owner.inputs[0], -scale, rval, 1)
254257

255-
elif r.owner and r.owner.op == mul:
258+
elif isinstance(scalar_op, Mul):
256259
scalars = []
257260
vectors = []
258261
matrices = []
@@ -460,35 +463,45 @@ def apply(self, fgraph):
460463
callbacks_before = fgraph.execute_callbacks_times.copy()
461464
callback_before = fgraph.execute_callbacks_time
462465

463-
nodelist = list(toposort(fgraph.outputs))
466+
relevant_core_ops = (
467+
pytensor.scalar.Add
468+
| pytensor.scalar.Sub
469+
| pytensor.scalar.Neg
470+
| pytensor.scalar.Mul
471+
)
472+
nodelist = [
473+
a
474+
for a in toposort(fgraph.outputs)
475+
if (
476+
isinstance(a.op, Elemwise)
477+
and isinstance(a.op.scalar_op, relevant_core_ops)
478+
)
479+
]
480+
if not nodelist:
481+
return None
482+
464483
nodelist.reverse()
465484

466485
def on_import(new_node):
467-
if new_node is not node:
486+
if (
487+
new_node is not node
488+
and isinstance(new_node.op, Elemwise)
489+
and isinstance(new_node.op.scalar_op, relevant_core_ops)
490+
):
468491
nodelist.append(new_node)
469492

470493
u = pytensor.graph.rewriting.basic.DispatchingFeature(
471494
on_import, None, None, name="GemmOptimizer"
472495
)
473496
fgraph.attach_feature(u)
497+
fgraph_apply_nodes = fgraph.apply_nodes
474498
while did_something:
475499
nb_iter += 1
476500
t0 = time.perf_counter()
477501
time_toposort += time.perf_counter() - t0
478502
did_something = False
479503
for node in nodelist:
480-
if not (
481-
isinstance(node.op, Elemwise)
482-
and isinstance(
483-
node.op.scalar_op,
484-
pytensor.scalar.Add
485-
| pytensor.scalar.Sub
486-
| pytensor.scalar.Neg
487-
| pytensor.scalar.Mul,
488-
)
489-
):
490-
continue
491-
if node not in fgraph.apply_nodes:
504+
if node not in fgraph_apply_nodes:
492505
# This means that we already removed this node from
493506
# the graph
494507
continue
@@ -502,7 +515,6 @@ def on_import(new_node):
502515
continue
503516
if new_outputs:
504517
new_outputs, old_dot22 = new_outputs
505-
assert len(new_outputs) == len(node.outputs)
506518
new_outputs[
507519
0
508520
].tag.values_eq_approx = values_eq_approx_remove_inf_nan
@@ -518,8 +530,7 @@ def on_import(new_node):
518530
did_something = True
519531
nb_replacement += 1
520532
except InconsistencyError:
521-
# TODO: retry other applications of gemm (see comment
522-
# in _gemm_from_node)
533+
# TODO: retry other applications of gemm (see comment in _gemm_from_node)
523534
nb_inconsistency_replace += 1
524535
except ReplacementDidNotRemoveError:
525536
nb_replacement_didn_t_remove += 1

0 commit comments

Comments
 (0)