Skip to content

Commit 9bc2a2f

Browse files
committed
Remove unnecessary checks and unused variable in Scan rewrites
1 parent 2ce0ce1 commit 9bc2a2f

File tree

1 file changed

+3
-9
lines changed

1 file changed

+3
-9
lines changed

pytensor/scan/rewriting.py

Lines changed: 3 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -220,9 +220,6 @@ def scan_push_out_non_seq(fgraph, node):
220220
it to the outer function to be executed only once, before the `Scan` `Op`,
221221
reduces the amount of computation that needs to be performed.
222222
"""
223-
if not isinstance(node.op, Scan):
224-
return False
225-
226223
node_inputs, node_outputs = node.op.inner_inputs, node.op.inner_outputs
227224

228225
local_fgraph_topo = io_toposort(node_inputs, node_outputs)
@@ -430,9 +427,6 @@ def scan_push_out_seq(fgraph, node):
430427
many times on many smaller tensors. In many cases, this optimization can
431428
increase memory usage but, in some specific cases, it can also decrease it.
432429
"""
433-
if not isinstance(node.op, Scan):
434-
return False
435-
436430
node_inputs, node_outputs = node.op.inner_inputs, node.op.inner_outputs
437431

438432
local_fgraph_topo = io_toposort(node_inputs, node_outputs)
@@ -696,7 +690,6 @@ def push_out_inner_vars(
696690
old_scan_args: ScanArgs,
697691
) -> tuple[list[Variable], ScanArgs, dict[Variable, Variable]]:
698692
tmp_outer_vars: list[Variable | None] = []
699-
new_scan_node = old_scan_node
700693
new_scan_args = old_scan_args
701694
replacements: dict[Variable, Variable] = {}
702695

@@ -843,10 +836,11 @@ def scan_push_out_add(fgraph, node):
843836
# Don't perform the optimization on `as_while` `Scan`s. Because these
844837
# `Scan`s don't run for a predetermined number of steps, handling them is
845838
# more complicated and this optimization doesn't support it at the moment.
846-
if not (isinstance(node.op, Scan) and not node.op.info.as_while):
839+
op = node.op
840+
if op.info.as_while:
847841
return False
848842

849-
op = node.op
843+
# apply_ancestors(args.inner_outputs)
850844

851845
# Use `ScanArgs` to parse the inputs and outputs of scan for ease of
852846
# use

0 commit comments

Comments
 (0)