@@ -621,65 +621,43 @@ def local_mul_switch_sink(fgraph, node):
621621 part of the graph.
622622
623623 """
624- for idx , i in enumerate (node .inputs ):
625- if i .owner and i .owner .op == switch :
626- switch_node = i .owner
627- try :
628- if (
629- get_underlying_scalar_constant_value (
630- switch_node .inputs [1 ], only_process_constants = True
631- )
632- == 0.0
633- ):
634- listmul = node .inputs [:idx ] + node .inputs [idx + 1 :]
635- fmul = mul (* ([* listmul , switch_node .inputs [2 ]]))
636-
637- # Copy over stacktrace for elementwise multiplication op
638- # from previous elementwise multiplication op.
639- # An error in the multiplication (e.g. errors due to
640- # inconsistent shapes), will point to the
641- # multiplication op.
642- copy_stack_trace (node .outputs , fmul )
643-
644- fct = [switch (switch_node .inputs [0 ], 0 , fmul )]
645- fct [0 ].tag .values_eq_approx = values_eq_approx_remove_nan
646-
647- # Copy over stacktrace for switch op from both previous
648- # elementwise multiplication op and previous switch op,
649- # because an error in this part can be caused by either
650- # of the two previous ops.
651- copy_stack_trace (node .outputs + switch_node .outputs , fct )
652- return fct
653- except NotScalarConstantError :
654- pass
655- try :
656- if (
657- get_underlying_scalar_constant_value (
658- switch_node .inputs [2 ], only_process_constants = True
659- )
660- == 0.0
661- ):
662- listmul = node .inputs [:idx ] + node .inputs [idx + 1 :]
663- fmul = mul (* ([* listmul , switch_node .inputs [1 ]]))
664- # Copy over stacktrace for elementwise multiplication op
665- # from previous elementwise multiplication op.
666- # An error in the multiplication (e.g. errors due to
667- # inconsistent shapes), will point to the
668- # multiplication op.
669- copy_stack_trace (node .outputs , fmul )
670-
671- fct = [switch (switch_node .inputs [0 ], fmul , 0 )]
672- fct [0 ].tag .values_eq_approx = values_eq_approx_remove_nan
673-
674- # Copy over stacktrace for switch op from both previous
675- # elementwise multiplication op and previous switch op,
676- # because an error in this part can be caused by either
677- # of the two previous ops.
678- copy_stack_trace (node .outputs + switch_node .outputs , fct )
679- return fct
680- except NotScalarConstantError :
681- pass
682- return False
624+ for mul_inp_idx , mul_inp in enumerate (node .inputs ):
625+ if mul_inp .owner and mul_inp .owner .op == switch :
626+ switch_node = mul_inp .owner
627+ # Look for a zero as the first or second branch of the switch
628+ for branch in range (2 ):
629+ zero_switch_input = switch_node .inputs [1 + branch ]
630+ if not get_unique_constant_value (zero_switch_input ) == 0.0 :
631+ continue
632+
633+ switch_cond = switch_node .inputs [0 ]
634+ other_switch_input = switch_node .inputs [1 + (1 - branch )]
635+
636+ listmul = list (node .inputs )
637+ listmul [mul_inp_idx ] = other_switch_input
638+ fmul = mul (* listmul )
639+
640+ # Copy over stacktrace for elementwise multiplication op
641+ # from previous elementwise multiplication op.
642+ # An error in the multiplication (e.g. errors due to
643+ # inconsistent shapes), will point to the
644+ # multiplication op.
645+ copy_stack_trace (node .outputs , fmul )
646+
647+ if branch == 0 :
648+ fct = switch (switch_cond , zero_switch_input , fmul )
649+ else :
650+ fct = switch (switch_cond , fmul , zero_switch_input )
651+
652+ # Tell debug_mode than the output is correct, even if nan disappear
653+ fct .tag .values_eq_approx = values_eq_approx_remove_nan
654+
655+ # Copy over stacktrace for switch op from both previous
656+ # elementwise multiplication op and previous switch op,
657+ # because an error in this part can be caused by either
658+ # of the two previous ops.
659+ copy_stack_trace (node .outputs + switch_node .outputs , fct )
660+ return [fct ]
683661
684662
685663@register_canonicalize
@@ -699,62 +677,39 @@ def local_div_switch_sink(fgraph, node):
699677 See `local_mul_switch_sink` for more details.
700678
701679 """
702- op = node .op
703- if node .inputs [0 ].owner and node .inputs [0 ].owner .op == switch :
704- switch_node = node .inputs [0 ].owner
705- try :
706- if (
707- get_underlying_scalar_constant_value (
708- switch_node .inputs [1 ], only_process_constants = True
709- )
710- == 0.0
711- ):
712- fdiv = op (switch_node .inputs [2 ], node .inputs [1 ])
713- # Copy over stacktrace for elementwise division op
714- # from previous elementwise multiplication op.
715- # An error in the division (e.g. errors due to
716- # inconsistent shapes or division by zero),
717- # will point to the new division op.
718- copy_stack_trace (node .outputs , fdiv )
719-
720- fct = [switch (switch_node .inputs [0 ], 0 , fdiv )]
721- fct [0 ].tag .values_eq_approx = values_eq_approx_remove_nan
722-
723- # Copy over stacktrace for switch op from both previous
724- # elementwise division op and previous switch op,
725- # because an error in this part can be caused by either
726- # of the two previous ops.
727- copy_stack_trace (node .outputs + switch_node .outputs , fct )
728- return fct
729- except NotScalarConstantError :
730- pass
731- try :
732- if (
733- get_underlying_scalar_constant_value (
734- switch_node .inputs [2 ], only_process_constants = True
735- )
736- == 0.0
737- ):
738- fdiv = op (switch_node .inputs [1 ], node .inputs [1 ])
739- # Copy over stacktrace for elementwise division op
740- # from previous elementwise multiplication op.
741- # An error in the division (e.g. errors due to
742- # inconsistent shapes or division by zero),
743- # will point to the new division op.
744- copy_stack_trace (node .outputs , fdiv )
745-
746- fct = [switch (switch_node .inputs [0 ], fdiv , 0 )]
747- fct [0 ].tag .values_eq_approx = values_eq_approx_remove_nan
680+ num , denom = node .inputs
748681
749- # Copy over stacktrace for switch op from both previous
750- # elementwise division op and previous switch op,
751- # because an error in this part can be caused by either
752- # of the two previous ops.
753- copy_stack_trace (node .outputs + switch_node .outputs , fct )
754- return fct
755- except NotScalarConstantError :
756- pass
757- return False
682+ if num .owner and num .owner .op == switch :
683+ switch_node = num .owner
684+ # Look for a zero as the first or second branch of the switch
685+ for branch in range (2 ):
686+ zero_switch_input = switch_node .inputs [1 + branch ]
687+ if not get_unique_constant_value (zero_switch_input ) == 0.0 :
688+ continue
689+
690+ switch_cond = switch_node .inputs [0 ]
691+ other_switch_input = switch_node .inputs [1 + (1 - branch )]
692+
693+ fdiv = node .op (other_switch_input , denom )
694+
695+ # Copy over stacktrace for elementwise division op
696+ # from previous elementwise multiplication op.
697+ # An error in the division (e.g. errors due to
698+ # inconsistent shapes or division by zero),
699+ # will point to the new division op.
700+ copy_stack_trace (node .outputs , fdiv )
701+
702+ fct = switch (switch_cond , zero_switch_input , fdiv )
703+
704+ # Tell debug_mode than the output is correct, even if nan disappear
705+ fct .tag .values_eq_approx = values_eq_approx_remove_nan
706+
707+ # Copy over stacktrace for switch op from both previous
708+ # elementwise division op and previous switch op,
709+ # because an error in this part can be caused by either
710+ # of the two previous ops.
711+ copy_stack_trace (node .outputs + switch_node .outputs , fct )
712+ return [fct ]
758713
759714
760715class AlgebraicCanonizer (NodeRewriter ):
0 commit comments