Skip to content

Commit 52296ff

Browse files
committed
Allow non-shared untraced SIT-SOT
1 parent bc72ef3 commit 52296ff

File tree

8 files changed

+341
-207
lines changed

8 files changed

+341
-207
lines changed

pytensor/link/jax/dispatch/scan.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ def scan(*outer_inputs):
6060
mit_mot_init,
6161
mit_sot_init,
6262
sit_sot_init,
63-
op.outer_shared(outer_inputs),
63+
op.outer_untraced_sit_sot(outer_inputs),
6464
op.outer_non_seqs(outer_inputs),
6565
) # JAX `init`
6666

@@ -118,7 +118,7 @@ def inner_func_outs_to_jax_outs(
118118
):
119119
"""Convert inner_scan_func outputs into format expected by JAX scan.
120120
121-
old_carry + (MIT-SOT_outs, SIT-SOT_outs, NIT-SOT_outs, shared_outs) -> (new_carry, ys)
121+
old_carry + (MIT-SOT_outs, SIT-SOT_outs, NIT-SOT_outs, untraced_SIT-SOT_outs) -> (new_carry, ys)
122122
"""
123123
(
124124
i,
@@ -133,7 +133,7 @@ def inner_func_outs_to_jax_outs(
133133
new_mit_sot_vals = op.inner_mitsot_outs(inner_scan_outs)
134134
new_sit_sot = op.inner_sitsot_outs(inner_scan_outs)
135135
new_nit_sot = op.inner_nitsot_outs(inner_scan_outs)
136-
new_shared = op.inner_shared_outs(inner_scan_outs)
136+
new_shared = op.inner_untraced_sit_sot_outs(inner_scan_outs)
137137

138138
# New carry for next step
139139
# Update MIT-MOT buffer at positions indicated by output taps

pytensor/link/numba/dispatch/scan.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -108,7 +108,7 @@ def numba_funcify_Scan(op: Scan, node, **kwargs):
108108
outer_in_mit_sot_names = op.outer_mitsot(outer_in_names)
109109
outer_in_sit_sot_names = op.outer_sitsot(outer_in_names)
110110
outer_in_nit_sot_names = op.outer_nitsot(outer_in_names)
111-
outer_in_shared_names = op.outer_shared(outer_in_names)
111+
outer_in_shared_names = op.outer_untraced_sit_sot(outer_in_names)
112112
outer_in_non_seqs_names = op.outer_non_seqs(outer_in_names)
113113

114114
# These are all the outer-input names that produce outputs/have output
@@ -204,11 +204,9 @@ def add_inner_in_expr(
204204

205205
# Inner-outputs consist of:
206206
# mit-mot-outputs + mit-sot-outputs + sit-sot-outputs + nit-sots +
207-
# shared-outputs [+ while-condition]
207+
# untraced-sit-sot-outputs [+ while-condition]
208208
inner_output_names = [f"inner_out_{i}" for i in range(len(op.inner_outputs))]
209209

210-
# inner_out_shared_names = op.inner_shared_outs(inner_output_names)
211-
212210
# The assignment statements that copy inner-outputs into the outer-outputs
213211
# storage
214212
inner_out_to_outer_in_stmts: list[str] = []

pytensor/scan/basic.py

Lines changed: 89 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import typing
12
import warnings
23
from itertools import chain
34

@@ -11,6 +12,7 @@
1112
from pytensor.graph.op import get_test_value
1213
from pytensor.graph.replace import clone_replace
1314
from pytensor.graph.traversal import explicit_graph_inputs
15+
from pytensor.graph.type import HasShape
1416
from pytensor.graph.utils import MissingInputError, TestValueError
1517
from pytensor.scan.op import Scan, ScanInfo
1618
from pytensor.scan.utils import expand_empty, safe_new, until
@@ -22,6 +24,10 @@
2224
from pytensor.updates import OrderedUpdates
2325

2426

27+
if typing.TYPE_CHECKING:
28+
from pytensor.tensor.type import TensorVariable
29+
30+
2531
def get_updates_and_outputs(ls):
2632
"""Recognize and order the updates, outputs, and stopping condition for a `Scan`.
2733
@@ -469,7 +475,7 @@ def wrap_into_list(x):
469475

470476
# Make sure we get rid of numpy arrays or ints or anything like that
471477
# passed as inputs to scan
472-
non_seqs = []
478+
non_seqs: list[Variable] = []
473479
for elem in wrap_into_list(non_sequences):
474480
if not isinstance(elem, Variable):
475481
non_seqs.append(pt.as_tensor_variable(elem))
@@ -685,10 +691,10 @@ def wrap_into_list(x):
685691

686692
# MIT_MOT -- not provided by the user, only by the grad function
687693
n_mit_mot = 0
688-
mit_mot_scan_inputs = []
689-
mit_mot_inner_inputs = []
690-
mit_mot_inner_outputs = []
691-
mit_mot_out_slices = []
694+
mit_mot_scan_inputs: list[TensorVariable] = []
695+
mit_mot_inner_inputs: list[TensorVariable] = []
696+
mit_mot_inner_outputs: list[TensorVariable] = []
697+
mit_mot_out_slices: list[TensorVariable] = []
692698

693699
# MIT_SOT and SIT_SOT -- provided by the user
694700
n_mit_sot = 0
@@ -706,6 +712,12 @@ def wrap_into_list(x):
706712
sit_sot_inner_outputs = []
707713
sit_sot_rightOrder = []
708714

715+
n_untraced_sit_sot_outs = 0
716+
untraced_sit_sot_scan_inputs = []
717+
untraced_sit_sot_inner_inputs = []
718+
untraced_sit_sot_inner_outputs = []
719+
untraced_sit_sot_rightOrder = []
720+
709721
# go through outputs picking up time slices as needed
710722
for i, init_out in enumerate(outs_info):
711723
# Note that our convention dictates that if an output uses
@@ -741,17 +753,35 @@ def wrap_into_list(x):
741753
# We now need to allocate space for storing the output and copy
742754
# the initial state over. We do this using the `expand_empty` function
743755
# defined in scan utils
744-
sit_sot_scan_inputs.append(
745-
expand_empty(
746-
shape_padleft(actual_arg),
747-
actual_n_steps,
756+
if isinstance(actual_arg.type, HasShape):
757+
sit_sot_scan_inputs.append(
758+
expand_empty(
759+
shape_padleft(actual_arg),
760+
actual_n_steps,
761+
)
748762
)
749-
)
763+
sit_sot_inner_slices.append(actual_arg)
750764

751-
sit_sot_inner_slices.append(actual_arg)
752-
sit_sot_inner_inputs.append(arg)
753-
sit_sot_rightOrder.append(i)
754-
n_sit_sot += 1
765+
sit_sot_inner_inputs.append(arg)
766+
sit_sot_rightOrder.append(i)
767+
n_sit_sot += 1
768+
else:
769+
# Assume variables without shape cannot be stacked (e.g., RNG variables)
770+
# Because this is new, issue a warning to inform the user, except for RNG variables, which were the main reason for this feature
771+
from pytensor.tensor.random.type import RandomType
772+
773+
if not isinstance(arg.type, RandomType):
774+
warnings.warn(
775+
(
776+
f"Output {actual_arg} (index {i}) with type {actual_arg.type} will be treated as untraced variable in scan. "
777+
"Only the last value will be returned, not the entire sequence."
778+
),
779+
UserWarning,
780+
)
781+
untraced_sit_sot_scan_inputs.append(actual_arg)
782+
untraced_sit_sot_inner_inputs.append(arg)
783+
n_untraced_sit_sot_outs += 1
784+
untraced_sit_sot_rightOrder.append(i)
755785

756786
elif init_out.get("taps", None):
757787
if np.any(np.array(init_out.get("taps", [])) > 0):
@@ -802,10 +832,11 @@ def wrap_into_list(x):
802832
# a map); in that case we do not have to do anything ..
803833

804834
# Re-order args
805-
max_mit_sot = np.max([-1, *mit_sot_rightOrder]) + 1
806-
max_sit_sot = np.max([-1, *sit_sot_rightOrder]) + 1
807-
n_elems = np.max([max_mit_sot, max_sit_sot])
808-
_ordered_args = [[] for x in range(n_elems)]
835+
max_mit_sot = max(mit_sot_rightOrder, default=-1) + 1
836+
max_sit_sot = max(sit_sot_rightOrder, default=-1) + 1
837+
max_untraced_sit_sot_outs = max(untraced_sit_sot_rightOrder, default=-1) + 1
838+
n_elems = np.max((max_mit_sot, max_sit_sot, max_untraced_sit_sot_outs))
839+
_ordered_args: list[list[Variable]] = [[] for x in range(n_elems)]
809840
offset = 0
810841
for idx in range(n_mit_sot):
811842
n_inputs = len(mit_sot_tap_array[idx])
@@ -825,6 +856,11 @@ def wrap_into_list(x):
825856
else:
826857
_ordered_args[sit_sot_rightOrder[idx]] = [sit_sot_inner_inputs[idx]]
827858

859+
for idx in range(n_untraced_sit_sot_outs):
860+
_ordered_args[untraced_sit_sot_rightOrder[idx]] = [
861+
untraced_sit_sot_inner_inputs[idx]
862+
]
863+
828864
ordered_args = list(chain.from_iterable(_ordered_args))
829865
if single_step_requested:
830866
args = inner_slices + ordered_args + non_seqs
@@ -939,18 +975,19 @@ def wrap_into_list(x):
939975
if "taps" in out and out["taps"] != [-1]:
940976
mit_sot_inner_outputs.append(outputs[i])
941977

942-
# Step 5.2 Outputs with tap equal to -1
978+
# Step 5.2 Outputs with tap equal to -1 (traced and untraced)
943979
for i, out in enumerate(outs_info):
944980
if "taps" in out and out["taps"] == [-1]:
945-
sit_sot_inner_outputs.append(outputs[i])
981+
output = outputs[i]
982+
if isinstance(output.type, HasShape):
983+
sit_sot_inner_outputs.append(output)
984+
else:
985+
untraced_sit_sot_inner_outputs.append(output)
946986

947987
# Step 5.3 Outputs that correspond to update rules of shared variables
988+
# This whole special logic for shared variables is deprecated
989+
sit_sot_shared: list[Variable] = []
948990
inner_replacements = {}
949-
n_shared_outs = 0
950-
shared_scan_inputs = []
951-
shared_inner_inputs = []
952-
shared_inner_outputs = []
953-
sit_sot_shared = []
954991
no_update_shared_inputs = []
955992
for input in dummy_inputs:
956993
if not isinstance(input.variable, SharedVariable):
@@ -976,8 +1013,8 @@ def wrap_into_list(x):
9761013

9771014
new_var = safe_new(input.variable)
9781015

979-
if getattr(input.variable, "name", None) is not None:
980-
new_var.name = input.variable.name + "_copy"
1016+
if input.variable.name is not None:
1017+
new_var.name = f"{input.variable.name}_copy"
9811018

9821019
inner_replacements[input.variable] = new_var
9831020

@@ -1003,10 +1040,10 @@ def wrap_into_list(x):
10031040
sit_sot_shared.append(input.variable)
10041041

10051042
else:
1006-
shared_inner_inputs.append(new_var)
1007-
shared_scan_inputs.append(input.variable)
1008-
shared_inner_outputs.append(input.update)
1009-
n_shared_outs += 1
1043+
untraced_sit_sot_inner_inputs.append(new_var)
1044+
untraced_sit_sot_scan_inputs.append(input.variable)
1045+
untraced_sit_sot_inner_outputs.append(input.update)
1046+
n_untraced_sit_sot_outs += 1
10101047
else:
10111048
no_update_shared_inputs.append(input)
10121049

@@ -1071,7 +1108,7 @@ def wrap_into_list(x):
10711108
+ mit_mot_inner_inputs
10721109
+ mit_sot_inner_inputs
10731110
+ sit_sot_inner_inputs
1074-
+ shared_inner_inputs
1111+
+ untraced_sit_sot_inner_inputs
10751112
+ other_shared_inner_args
10761113
+ other_inner_args
10771114
)
@@ -1081,7 +1118,7 @@ def wrap_into_list(x):
10811118
+ mit_sot_inner_outputs
10821119
+ sit_sot_inner_outputs
10831120
+ nit_sot_inner_outputs
1084-
+ shared_inner_outputs
1121+
+ untraced_sit_sot_inner_outputs
10851122
)
10861123
if condition is not None:
10871124
inner_outs.append(condition)
@@ -1101,7 +1138,7 @@ def wrap_into_list(x):
11011138
mit_mot_out_slices=tuple(tuple(v) for v in mit_mot_out_slices),
11021139
mit_sot_in_slices=tuple(tuple(v) for v in mit_sot_tap_array),
11031140
sit_sot_in_slices=tuple((-1,) for x in range(n_sit_sot)),
1104-
n_shared_outs=n_shared_outs,
1141+
n_untraced_sit_sot_outs=n_untraced_sit_sot_outs,
11051142
n_nit_sot=n_nit_sot,
11061143
n_non_seqs=len(other_shared_inner_args) + len(other_inner_args),
11071144
as_while=as_while,
@@ -1127,7 +1164,7 @@ def wrap_into_list(x):
11271164
+ mit_mot_scan_inputs
11281165
+ mit_sot_scan_inputs
11291166
+ sit_sot_scan_inputs
1130-
+ shared_scan_inputs
1167+
+ untraced_sit_sot_scan_inputs
11311168
+ [actual_n_steps for x in range(n_nit_sot)]
11321169
+ other_shared_scan_args
11331170
+ other_scan_args
@@ -1173,13 +1210,26 @@ def remove_dimensions(outs, offsets=None):
11731210
nit_sot_outs = remove_dimensions(scan_outs[offset : offset + n_nit_sot])
11741211

11751212
offset += n_nit_sot
1176-
for idx, update_rule in enumerate(scan_outs[offset : offset + n_shared_outs]):
1177-
update_map[shared_scan_inputs[idx]] = update_rule
11781213

1179-
_scan_out_list = mit_sot_outs + sit_sot_outs + nit_sot_outs
1214+
# Support for explicit untraced sit_sot
1215+
n_explicit_untraced_sit_sot_outs = len(untraced_sit_sot_rightOrder)
1216+
untraced_sit_sot_outs = scan_outs[
1217+
offset : offset + n_explicit_untraced_sit_sot_outs
1218+
]
1219+
1220+
offset += n_explicit_untraced_sit_sot_outs
1221+
for idx, update_rule in enumerate(scan_outs[offset:]):
1222+
update_map[untraced_sit_sot_scan_inputs[idx]] = update_rule
1223+
1224+
_scan_out_list = mit_sot_outs + sit_sot_outs + nit_sot_outs + untraced_sit_sot_outs
11801225
# Step 10. I need to reorder the outputs to be in the order expected by
11811226
# the user
1182-
rightOrder = mit_sot_rightOrder + sit_sot_rightOrder + nit_sot_rightOrder
1227+
rightOrder = (
1228+
mit_sot_rightOrder
1229+
+ sit_sot_rightOrder
1230+
+ untraced_sit_sot_rightOrder
1231+
+ nit_sot_rightOrder
1232+
)
11831233
scan_out_list = [None] * len(rightOrder)
11841234
for idx, pos in enumerate(rightOrder):
11851235
if pos >= 0:

0 commit comments

Comments
 (0)