Skip to content

Commit 2ce0ce1

Browse files
committed
Note failing scan rewrite
1 parent 9a124ca commit 2ce0ce1

File tree

2 files changed

+24
-5
lines changed

2 files changed

+24
-5
lines changed

pytensor/scan/rewriting.py

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -658,10 +658,9 @@ def inner_sitsot_only_last_step_used(
658658
fgraph: FunctionGraph, var: Variable, scan_args: ScanArgs
659659
) -> bool:
660660
"""
661-
Given a inner nit-sot output of `Scan`, return ``True`` iff the outer
662-
nit-sot output has only one client and that client is a `Subtensor`
663-
instance that takes only the last step (last element along the first
664-
axis).
661+
Given a inner sit-sot output of `Scan`, return ``True`` iff the outer
662+
sit-sot output has only one client and that client is a `Subtensor`
663+
instance that takes only the last step (last element along the first axis).
665664
"""
666665
idx = scan_args.inner_out_sit_sot.index(var)
667666
outer_var = scan_args.outer_out_sit_sot[idx]
@@ -832,6 +831,14 @@ def scan_push_out_add(fgraph, node):
832831
Like `scan_push_out_seq`, this optimization aims to replace many operations
833832
on small tensors by few operations on large tensors. It can also lead to
834833
increased memory usage.
834+
835+
FIXME: This rewrite doesn't cover user defined graphs,
836+
since it doesn't account for the intermediate slice
837+
returned by the scan constructor for sit-sot (i.e., something like output[1:]).
838+
It only looks for `outputs[-1]` but the user will only ever write `outputs[1:][-1]`
839+
The relevant helper function is `inner_sitsot_only_last_step_used` which is only used by this rewrite
840+
Note this rewrite is registered before subtensor_merge, but even if it were after subtensor_merge is a mess
841+
and doesn't simplify to x[1:][-1] to x[-1] unless x length is statically known
835842
"""
836843
# Don't perform the optimization on `as_while` `Scan`s. Because these
837844
# `Scan`s don't run for a predetermined number of steps, handling them is
@@ -857,6 +864,7 @@ def scan_push_out_add(fgraph, node):
857864
isinstance(nd.op, Elemwise)
858865
and isinstance(nd.op.scalar_op, ps.Add)
859866
and nd.out in args.inner_out_sit_sot
867+
# FIXME: This function doesn't handle `sitsot_out[1:][-1]` pattern
860868
and inner_sitsot_only_last_step_used(fgraph, nd.out, args)
861869
):
862870
# Ensure that one of the input to the add is the output of
@@ -920,6 +928,7 @@ def scan_push_out_add(fgraph, node):
920928
# external Dot instead of the output of scan
921929
# Modify the outer graph to add the outer Dot
922930
outer_sitsot = new_scan_args.outer_out_sit_sot[sitsot_idx]
931+
# TODO: If we fix the FIXME above, we have to make sure we replace the last subtensor, not the immediate one
923932
subtensor_node = fgraph.clients[outer_sitsot][0][0]
924933
outer_sitsot_last_step = subtensor_node.outputs[0]
925934

tests/scan/test_rewriting.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -600,10 +600,12 @@ class TestPushOutAddScan:
600600
is used to compute the sum over the dot products between the corresponding
601601
elements of two list of matrices.
602602
603-
TODO FIXME XXX: These aren't real tests; they simply confirm that a few
603+
FIXME: These aren't real tests; they simply confirm that a few
604604
graph that could be relevant to the push-out optimizations can be compiled
605605
and evaluated. None of them confirm that a push-out optimization has been
606606
performed.
607+
608+
FIXME: The rewrite is indeed broken, probably fro a long while, see FIXME details in the respective rewrite
607609
"""
608610

609611
def test_sum_dot(self):
@@ -614,7 +616,15 @@ def test_sum_dot(self):
614616
sequences=[A.dimshuffle(0, 1, "x"), B.dimshuffle(0, "x", 1)],
615617
outputs_info=[pt.zeros_like(A)],
616618
)
619+
# FIXME: This `s.owner.inputs[0][-1]` is a hack, users will never do that.
620+
# They will do `s[-1]` which the rewrite fails to identify since it explicitly looks for a `scan_out[-1]`
621+
# instead of `scan_out[1:][-1]` that the user would define by writing `s[-1]`
622+
# It however, tests the only case the rewrite supports now
617623
f = function([A, B], S.owner.inputs[0][-1])
624+
has_scan = any(isinstance(node.op, Scan) for node in f.maker.fgraph.apply_nodes)
625+
# Rewrite is only triggered in fast_run mode
626+
assert has_scan if (config.mode == "FAST_COMPILE") else (not has_scan)
627+
618628
rng = np.random.default_rng(utt.fetch_seed())
619629
vA = rng.uniform(size=(5, 5)).astype(config.floatX)
620630
vB = rng.uniform(size=(5, 5)).astype(config.floatX)

0 commit comments

Comments
 (0)