@@ -220,9 +220,6 @@ def scan_push_out_non_seq(fgraph, node):
220220 it to the outer function to be executed only once, before the `Scan` `Op`,
221221 reduces the amount of computation that needs to be performed.
222222 """
223- if not isinstance (node .op , Scan ):
224- return False
225-
226223 node_inputs , node_outputs = node .op .inner_inputs , node .op .inner_outputs
227224
228225 local_fgraph_topo = io_toposort (node_inputs , node_outputs )
@@ -430,9 +427,6 @@ def scan_push_out_seq(fgraph, node):
430427 many times on many smaller tensors. In many cases, this optimization can
431428 increase memory usage but, in some specific cases, it can also decrease it.
432429 """
433- if not isinstance (node .op , Scan ):
434- return False
435-
436430 node_inputs , node_outputs = node .op .inner_inputs , node .op .inner_outputs
437431
438432 local_fgraph_topo = io_toposort (node_inputs , node_outputs )
@@ -696,7 +690,6 @@ def push_out_inner_vars(
696690 old_scan_args : ScanArgs ,
697691) -> tuple [list [Variable ], ScanArgs , dict [Variable , Variable ]]:
698692 tmp_outer_vars : list [Variable | None ] = []
699- new_scan_node = old_scan_node
700693 new_scan_args = old_scan_args
701694 replacements : dict [Variable , Variable ] = {}
702695
@@ -843,10 +836,11 @@ def scan_push_out_add(fgraph, node):
843836 # Don't perform the optimization on `as_while` `Scan`s. Because these
844837 # `Scan`s don't run for a predetermined number of steps, handling them is
845838 # more complicated and this optimization doesn't support it at the moment.
846- if not (isinstance (node .op , Scan ) and not node .op .info .as_while ):
839+ op = node .op
840+ if op .info .as_while :
847841 return False
848842
849- op = node . op
843+ # apply_ancestors(args.inner_outputs)
850844
851845 # Use `ScanArgs` to parse the inputs and outputs of scan for ease of
852846 # use
0 commit comments