6060import numpy as np
6161
6262from pytensor .graph .traversal import toposort
63+ from pytensor .scalar import Add , Mul , Neg , Sub
6364from pytensor .tensor .rewriting .basic import register_specialize
6465
6566
100101from pytensor .tensor .math import (
101102 Dot ,
102103 _matmul ,
103- add ,
104104 mul ,
105- neg ,
106- sub ,
107105 variadic_add ,
108106)
109107from pytensor .tensor .rewriting .elemwise import local_dimshuffle_lift
@@ -237,22 +235,27 @@ def scaled(thing):
237235 rval .append (scaled (r ))
238236 return rval
239237
240- if maxclients and len (fgraph .clients [r ]) > maxclients :
238+ if (
239+ (r .owner is None )
240+ or (not isinstance (r .owner .op , Elemwise ))
241+ or (maxclients and len (fgraph .clients [r ]) > maxclients )
242+ ):
241243 rval .append ((scale , r ))
242244 return rval
243245
244- if r .owner and r .owner .op == sub :
246+ scalar_op = r .owner .op .scalar_op
247+ if isinstance (scalar_op , Sub ):
245248 _gemm_canonicalize (fgraph , r .owner .inputs [0 ], scale , rval , 1 )
246249 _gemm_canonicalize (fgraph , r .owner .inputs [1 ], - scale , rval , 1 )
247250
248- elif r . owner and r . owner . op == add :
251+ elif isinstance ( scalar_op , Add ) :
249252 for i in r .owner .inputs :
250253 _gemm_canonicalize (fgraph , i , scale , rval , 1 )
251254
252- elif r . owner and r . owner . op == neg :
255+ elif isinstance ( scalar_op , Neg ) :
253256 _gemm_canonicalize (fgraph , r .owner .inputs [0 ], - scale , rval , 1 )
254257
255- elif r . owner and r . owner . op == mul :
258+ elif isinstance ( scalar_op , Mul ) :
256259 scalars = []
257260 vectors = []
258261 matrices = []
@@ -460,35 +463,45 @@ def apply(self, fgraph):
460463 callbacks_before = fgraph .execute_callbacks_times .copy ()
461464 callback_before = fgraph .execute_callbacks_time
462465
463- nodelist = list (toposort (fgraph .outputs ))
466+ relevant_core_ops = (
467+ pytensor .scalar .Add
468+ | pytensor .scalar .Sub
469+ | pytensor .scalar .Neg
470+ | pytensor .scalar .Mul
471+ )
472+ nodelist = [
473+ a
474+ for a in toposort (fgraph .outputs )
475+ if (
476+ isinstance (a .op , Elemwise )
477+ and isinstance (a .op .scalar_op , relevant_core_ops )
478+ )
479+ ]
480+ if not nodelist :
481+ return None
482+
464483 nodelist .reverse ()
465484
466485 def on_import (new_node ):
467- if new_node is not node :
486+ if (
487+ new_node is not node
488+ and isinstance (new_node .op , Elemwise )
489+ and isinstance (new_node .op .scalar_op , relevant_core_ops )
490+ ):
468491 nodelist .append (new_node )
469492
470493 u = pytensor .graph .rewriting .basic .DispatchingFeature (
471494 on_import , None , None , name = "GemmOptimizer"
472495 )
473496 fgraph .attach_feature (u )
497+ fgraph_apply_nodes = fgraph .apply_nodes
474498 while did_something :
475499 nb_iter += 1
476500 t0 = time .perf_counter ()
477501 time_toposort += time .perf_counter () - t0
478502 did_something = False
479503 for node in nodelist :
480- if not (
481- isinstance (node .op , Elemwise )
482- and isinstance (
483- node .op .scalar_op ,
484- pytensor .scalar .Add
485- | pytensor .scalar .Sub
486- | pytensor .scalar .Neg
487- | pytensor .scalar .Mul ,
488- )
489- ):
490- continue
491- if node not in fgraph .apply_nodes :
504+ if node not in fgraph_apply_nodes :
492505 # This mean that we already removed this node from
493506 # the graph
494507 continue
@@ -502,7 +515,6 @@ def on_import(new_node):
502515 continue
503516 if new_outputs :
504517 new_outputs , old_dot22 = new_outputs
505- assert len (new_outputs ) == len (node .outputs )
506518 new_outputs [
507519 0
508520 ].tag .values_eq_approx = values_eq_approx_remove_inf_nan
@@ -518,8 +530,7 @@ def on_import(new_node):
518530 did_something = True
519531 nb_replacement += 1
520532 except InconsistencyError :
521- # TODO: retry other applications of gemm (see comment
522- # in _gemm_from_node)
533+ # TODO: retry other applications of gemm (see comment in _gemm_from_node)
523534 nb_inconsistency_replace += 1
524535 except ReplacementDidNotRemoveError :
525536 nb_replacement_didn_t_remove += 1
0 commit comments