@@ -26,6 +26,7 @@
     out2in,
 )
 from pytensor.graph.rewriting.db import SequenceDB
+from pytensor.graph.rewriting.unify import OpPattern
 from pytensor.graph.utils import InconsistencyError, MethodNotDefined
 from pytensor.scalar.math import Grad2F1Loop, _grad_2f1_loop
 from pytensor.tensor.basic import (
@@ -37,6 +38,7 @@
 from pytensor.tensor.rewriting.basic import (
     alloc_like,
     broadcasted_by,
+    elemwise_of,
    register_canonicalize,
    register_specialize,
    register_stabilize,
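
The two new imports back a more declarative rewriter API: rather than subscribing to every `Elemwise` node and filtering with `isinstance` checks inside the function body, a rewriter now declares up front which scalar ops it targets. A minimal sketch of the idiom, assuming `elemwise_of` and `OpPattern` behave as the hunks below suggest (the rewriter name and the `ps.Mul` target are illustrative, not part of this diff):

import pytensor.scalar as ps
from pytensor.graph.rewriting.basic import node_rewriter
from pytensor.tensor.rewriting.basic import elemwise_of

# Fires only for Elemwise nodes whose scalar_op is a Mul; the
# isinstance guard that used to open the function body is now
# encoded in the trigger itself.
@node_rewriter([elemwise_of(ps.Mul)])
def local_my_mul_rewrite(fgraph, node):
    # node.op.scalar_op is guaranteed to be a ps.Mul here.
    ...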
@@ -422,7 +424,14 @@ def local_useless_dimshuffle_makevector(fgraph, node):
 
 
 @register_canonicalize
-@node_rewriter([Elemwise])
+@node_rewriter(
+    [
+        elemwise_of(
+            OpPattern(ps.ScalarOp, output_types_preference=ps.upgrade_to_float)
+        ),
+        elemwise_of(OpPattern(ps.ScalarOp, output_types_preference=ps.upcast_out)),
+    ]
+)
 def local_upcast_elemwise_constant_inputs(fgraph, node):
     """This explicitly upcasts constant inputs to elemwise Ops, when
     those Ops do implicit upcasting anyway.
@@ -433,12 +442,6 @@ def local_upcast_elemwise_constant_inputs(fgraph, node):
     if len(node.outputs) > 1:
         return None
 
-    if getattr(node.op.scalar_op, "output_types_preference", None) not in (
-        ps.upgrade_to_float,
-        ps.upcast_out,
-    ):
-        return None
-
     # this is the kind of op that we can screw with the input
     # dtypes by upcasting explicitly
     [old_out] = node.outputs
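
With this guard gone, the filtering lives entirely in the decorator: the two `OpPattern` entries match exactly those scalar ops whose `output_types_preference` is `ps.upgrade_to_float` or `ps.upcast_out`, the same condition the deleted `getattr` check enforced at runtime. Passing several patterns to `node_rewriter` evidently acts as a union of triggers, so the body only runs for nodes matching one of them.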
@@ -988,13 +991,9 @@ def print_profile(stream, prof, level=0):
 
 @register_canonicalize
 @register_specialize
-@node_rewriter([Elemwise])
+@node_rewriter([elemwise_of(ps.Composite)])
 def local_useless_composite_outputs(fgraph, node):
     """Remove inputs and outputs of Composite Ops that are not used anywhere."""
-    if not (
-        isinstance(node.op, Elemwise) and isinstance(node.op.scalar_op, ps.Composite)
-    ):
-        return
     comp = node.op.scalar_op
     used_outputs_idxs = [
         i for i, o_extern in enumerate(node.outputs) if fgraph.clients[o_extern]
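
Judging by this hunk and the ones below, `elemwise_of` appears to accept either form of specifier:

elemwise_of(ps.Composite)  # match on the scalar Op class alone
elemwise_of(OpPattern(ps.ScalarOp, output_types_preference=ps.upcast_out))  # class plus attribute constraint

The bare-class form covers the simple cases, and in each of them the old `isinstance(node.op.scalar_op, ...)` bail-out becomes dead code and is dropped.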
@@ -1104,14 +1103,10 @@ def local_careduce_fusion(fgraph, node):
     return [new_car_op(*elm_inputs)]
 
 
-@node_rewriter([Elemwise])
+@node_rewriter([elemwise_of(ps.Composite)])
 def local_inline_composite_constants(fgraph, node):
     """Inline scalar constants in Composite graphs."""
     composite_op = node.op.scalar_op
-
-    if not isinstance(composite_op, ps.Composite):
-        return None
-
     new_outer_inputs = []
     new_inner_inputs = []
     inner_replacements = {}
@@ -1287,14 +1282,9 @@ def _rebuild_partial_2f1grad_loop(node, wrt):
 
 
 @register_specialize
-@node_rewriter([Elemwise])
+@node_rewriter([elemwise_of(Grad2F1Loop)])
 def local_useless_2f1grad_loop(fgraph, node):
     # Remove unused terms from the hyp2f1 grad loop
-
-    loop_op = node.op.scalar_op
-    if not isinstance(loop_op, Grad2F1Loop):
-        return
-
     grad_related_vars = node.outputs[:-4]
     # Rewrite was already applied
     if len(grad_related_vars) // 3 != 3:
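
Same refactor once more: `Grad2F1Loop` is a concrete scalar `Op` subclass (imported from `pytensor.scalar.math` at the top of the file), so the bare-class form of `elemwise_of` suffices, and both the `loop_op` assignment and its `isinstance` bail-out disappear from the body here and in `split_2f1grad_loop` below.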
@@ -1326,18 +1316,13 @@ def local_useless_2f1grad_loop(fgraph, node):
     return replacements
 
 
-@node_rewriter([Elemwise])
+@node_rewriter([elemwise_of(Grad2F1Loop)])
 def split_2f1grad_loop(fgraph, node):
     """
     2f1grad loop has too many operands for Numpy frompyfunc code used by Elemwise nodes on python mode.
 
     This rewrite splits it across 3 different operations. It is not needed if `local_useless_2f1grad_loop` was applied
     """
-    loop_op = node.op.scalar_op
-
-    if not isinstance(loop_op, Grad2F1Loop):
-        return None
-
     grad_related_vars = node.outputs[:-4]
     # local_useless_2f1grad_loop was used, we should be safe
     if len(grad_related_vars) // 3 != 3: