@@ -658,10 +658,9 @@ def inner_sitsot_only_last_step_used(
658658 fgraph : FunctionGraph , var : Variable , scan_args : ScanArgs
659659) -> bool :
660660 """
661- Given a inner nit-sot output of `Scan`, return ``True`` iff the outer
662- nit-sot output has only one client and that client is a `Subtensor`
663- instance that takes only the last step (last element along the first
664- axis).
661+ Given an inner sit-sot output of `Scan`, return ``True`` iff the outer
662+ sit-sot output has only one client and that client is a `Subtensor`
663+ instance that takes only the last step (last element along the first axis).
665664 """
666665 idx = scan_args .inner_out_sit_sot .index (var )
667666 outer_var = scan_args .outer_out_sit_sot [idx ]
@@ -832,6 +831,14 @@ def scan_push_out_add(fgraph, node):
832831 Like `scan_push_out_seq`, this optimization aims to replace many operations
833832 on small tensors by few operations on large tensors. It can also lead to
834833 increased memory usage.
834+
835+ FIXME: This rewrite doesn't cover user defined graphs,
836+ since it doesn't account for the intermediate slice
837+ returned by the scan constructor for sit-sot (i.e., something like output[1:]).
838+ It only looks for `outputs[-1]`, but the user will only ever write `outputs[1:][-1]`.
839+ The relevant helper function is `inner_sitsot_only_last_step_used`, which is only used by this rewrite.
840+ Note that this rewrite is registered before subtensor_merge, but even if it were registered after it, subtensor_merge is a mess
841+ and doesn't simplify x[1:][-1] to x[-1] unless the length of x is statically known.
835842 """
836843 # Don't perform the optimization on `as_while` `Scan`s. Because these
837844 # `Scan`s don't run for a predetermined number of steps, handling them is
@@ -857,6 +864,7 @@ def scan_push_out_add(fgraph, node):
857864 isinstance (nd .op , Elemwise )
858865 and isinstance (nd .op .scalar_op , ps .Add )
859866 and nd .out in args .inner_out_sit_sot
867+ # FIXME: This function doesn't handle the `sitsot_out[1:][-1]` pattern
860868 and inner_sitsot_only_last_step_used (fgraph , nd .out , args )
861869 ):
862870 # Ensure that one of the input to the add is the output of
@@ -920,6 +928,7 @@ def scan_push_out_add(fgraph, node):
920928 # external Dot instead of the output of scan
921929 # Modify the outer graph to add the outer Dot
922930 outer_sitsot = new_scan_args .outer_out_sit_sot [sitsot_idx ]
931+ # TODO: If we fix the FIXME above, we have to make sure we replace the last subtensor, not the immediate one
923932 subtensor_node = fgraph .clients [outer_sitsot ][0 ][0 ]
924933 outer_sitsot_last_step = subtensor_node .outputs [0 ]
925934
0 commit comments