Skip to content

Commit 5fc02e1

Browse files
committed
Allow non-shared untraced SIT-SOT
1 parent bc72ef3 commit 5fc02e1

File tree

8 files changed

+328
-199
lines changed

8 files changed

+328
-199
lines changed

pytensor/link/jax/dispatch/scan.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ def scan(*outer_inputs):
6060
mit_mot_init,
6161
mit_sot_init,
6262
sit_sot_init,
63-
op.outer_shared(outer_inputs),
63+
op.outer_untraced_sit_sot(outer_inputs),
6464
op.outer_non_seqs(outer_inputs),
6565
) # JAX `init`
6666

@@ -118,7 +118,7 @@ def inner_func_outs_to_jax_outs(
118118
):
119119
"""Convert inner_scan_func outputs into format expected by JAX scan.
120120
121-
old_carry + (MIT-SOT_outs, SIT-SOT_outs, NIT-SOT_outs, shared_outs) -> (new_carry, ys)
121+
old_carry + (MIT-SOT_outs, SIT-SOT_outs, NIT-SOT_outs, untraced_SIT-SOT_outs) -> (new_carry, ys)
122122
"""
123123
(
124124
i,
@@ -133,7 +133,7 @@ def inner_func_outs_to_jax_outs(
133133
new_mit_sot_vals = op.inner_mitsot_outs(inner_scan_outs)
134134
new_sit_sot = op.inner_sitsot_outs(inner_scan_outs)
135135
new_nit_sot = op.inner_nitsot_outs(inner_scan_outs)
136-
new_shared = op.inner_shared_outs(inner_scan_outs)
136+
new_shared = op.inner_untraced_sit_sot_outs(inner_scan_outs)
137137

138138
# New carry for next step
139139
# Update MIT-MOT buffer at positions indicated by output taps

pytensor/link/numba/dispatch/scan.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -108,7 +108,7 @@ def numba_funcify_Scan(op: Scan, node, **kwargs):
108108
outer_in_mit_sot_names = op.outer_mitsot(outer_in_names)
109109
outer_in_sit_sot_names = op.outer_sitsot(outer_in_names)
110110
outer_in_nit_sot_names = op.outer_nitsot(outer_in_names)
111-
outer_in_shared_names = op.outer_shared(outer_in_names)
111+
outer_in_shared_names = op.outer_untraced_sit_sot(outer_in_names)
112112
outer_in_non_seqs_names = op.outer_non_seqs(outer_in_names)
113113

114114
# These are all the outer-input names that have produce outputs/have output
@@ -204,11 +204,9 @@ def add_inner_in_expr(
204204

205205
# Inner-outputs consist of:
206206
# mit-mot-outputs + mit-sot-outputs + sit-sot-outputs + nit-sots +
207-
# shared-outputs [+ while-condition]
207+
# untraced-sit-sot-outputs [+ while-condition]
208208
inner_output_names = [f"inner_out_{i}" for i in range(len(op.inner_outputs))]
209209

210-
# inner_out_shared_names = op.inner_shared_outs(inner_output_names)
211-
212210
# The assignment statements that copy inner-outputs into the outer-outputs
213211
# storage
214212
inner_out_to_outer_in_stmts: list[str] = []

pytensor/scan/basic.py

Lines changed: 76 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
from pytensor.graph.op import get_test_value
1212
from pytensor.graph.replace import clone_replace
1313
from pytensor.graph.traversal import explicit_graph_inputs
14+
from pytensor.graph.type import HasShape
1415
from pytensor.graph.utils import MissingInputError, TestValueError
1516
from pytensor.scan.op import Scan, ScanInfo
1617
from pytensor.scan.utils import expand_empty, safe_new, until
@@ -706,6 +707,12 @@ def wrap_into_list(x):
706707
sit_sot_inner_outputs = []
707708
sit_sot_rightOrder = []
708709

710+
n_untraced_sit_sot_outs = 0
711+
untraced_sit_sot_scan_inputs = []
712+
untraced_sit_sot_inner_inputs = []
713+
untraced_sit_sot_inner_outputs = []
714+
untraced_sit_sot_rightOrder = []
715+
709716
# go through outputs picking up time slices as needed
710717
for i, init_out in enumerate(outs_info):
711718
# Note that our convention dictates that if an output uses
@@ -741,17 +748,35 @@ def wrap_into_list(x):
741748
# We need now to allocate space for storing the output and copy
742749
# the initial state over. We do this using the expand function
743750
# defined in scan utils
744-
sit_sot_scan_inputs.append(
745-
expand_empty(
746-
shape_padleft(actual_arg),
747-
actual_n_steps,
751+
if isinstance(actual_arg.type, HasShape):
752+
sit_sot_scan_inputs.append(
753+
expand_empty(
754+
shape_padleft(actual_arg),
755+
actual_n_steps,
756+
)
748757
)
749-
)
758+
sit_sot_inner_slices.append(actual_arg)
759+
760+
sit_sot_inner_inputs.append(arg)
761+
sit_sot_rightOrder.append(i)
762+
n_sit_sot += 1
763+
else:
764+
# Assume variables without shape cannot be stacked (e.g., RNG variables)
765+
# Because this is new, issue a warning to inform the user, except for RNG, which were the main reason for this feature
766+
from pytensor.tensor.random.type import RandomType
750767

751-
sit_sot_inner_slices.append(actual_arg)
752-
sit_sot_inner_inputs.append(arg)
753-
sit_sot_rightOrder.append(i)
754-
n_sit_sot += 1
768+
if not isinstance(arg.type, RandomType):
769+
warnings.warn(
770+
(
771+
f"Output {actual_arg} (index {i}) with type {actual_arg.type} will be treated as untraced variable in scan. "
772+
"Only the last value will be returned, not the entire sequence."
773+
),
774+
UserWarning,
775+
)
776+
untraced_sit_sot_scan_inputs.append(actual_arg)
777+
untraced_sit_sot_inner_inputs.append(arg)
778+
n_untraced_sit_sot_outs += 1
779+
untraced_sit_sot_rightOrder.append(i)
755780

756781
elif init_out.get("taps", None):
757782
if np.any(np.array(init_out.get("taps", [])) > 0):
@@ -802,9 +827,10 @@ def wrap_into_list(x):
802827
# a map); in that case we do not have to do anything ..
803828

804829
# Re-order args
805-
max_mit_sot = np.max([-1, *mit_sot_rightOrder]) + 1
806-
max_sit_sot = np.max([-1, *sit_sot_rightOrder]) + 1
807-
n_elems = np.max([max_mit_sot, max_sit_sot])
830+
max_mit_sot = max(mit_sot_rightOrder, default=-1) + 1
831+
max_sit_sot = max(sit_sot_rightOrder, default=-1) + 1
832+
max_untraced_sit_sot_outs = max(untraced_sit_sot_rightOrder, default=-1) + 1
833+
n_elems = np.max((max_mit_sot, max_sit_sot, max_untraced_sit_sot_outs))
808834
_ordered_args = [[] for x in range(n_elems)]
809835
offset = 0
810836
for idx in range(n_mit_sot):
@@ -825,6 +851,11 @@ def wrap_into_list(x):
825851
else:
826852
_ordered_args[sit_sot_rightOrder[idx]] = [sit_sot_inner_inputs[idx]]
827853

854+
for idx in range(n_untraced_sit_sot_outs):
855+
_ordered_args[untraced_sit_sot_rightOrder[idx]] = [
856+
untraced_sit_sot_inner_inputs[idx]
857+
]
858+
828859
ordered_args = list(chain.from_iterable(_ordered_args))
829860
if single_step_requested:
830861
args = inner_slices + ordered_args + non_seqs
@@ -939,18 +970,19 @@ def wrap_into_list(x):
939970
if "taps" in out and out["taps"] != [-1]:
940971
mit_sot_inner_outputs.append(outputs[i])
941972

942-
# Step 5.2 Outputs with tap equal to -1
973+
# Step 5.2 Outputs with tap equal to -1 (traced and untraced)
943974
for i, out in enumerate(outs_info):
944975
if "taps" in out and out["taps"] == [-1]:
945-
sit_sot_inner_outputs.append(outputs[i])
976+
output = outputs[i]
977+
if isinstance(output.type, HasShape):
978+
sit_sot_inner_outputs.append(output)
979+
else:
980+
untraced_sit_sot_inner_outputs.append(output)
946981

947982
# Step 5.3 Outputs that correspond to update rules of shared variables
948-
inner_replacements = {}
949-
n_shared_outs = 0
950-
shared_scan_inputs = []
951-
shared_inner_inputs = []
952-
shared_inner_outputs = []
983+
# This whole special logic for shared variables is deprecated
953984
sit_sot_shared = []
985+
inner_replacements = {}
954986
no_update_shared_inputs = []
955987
for input in dummy_inputs:
956988
if not isinstance(input.variable, SharedVariable):
@@ -1003,10 +1035,10 @@ def wrap_into_list(x):
10031035
sit_sot_shared.append(input.variable)
10041036

10051037
else:
1006-
shared_inner_inputs.append(new_var)
1007-
shared_scan_inputs.append(input.variable)
1008-
shared_inner_outputs.append(input.update)
1009-
n_shared_outs += 1
1038+
untraced_sit_sot_inner_inputs.append(new_var)
1039+
untraced_sit_sot_scan_inputs.append(input.variable)
1040+
untraced_sit_sot_inner_outputs.append(input.update)
1041+
n_untraced_sit_sot_outs += 1
10101042
else:
10111043
no_update_shared_inputs.append(input)
10121044

@@ -1071,7 +1103,7 @@ def wrap_into_list(x):
10711103
+ mit_mot_inner_inputs
10721104
+ mit_sot_inner_inputs
10731105
+ sit_sot_inner_inputs
1074-
+ shared_inner_inputs
1106+
+ untraced_sit_sot_inner_inputs
10751107
+ other_shared_inner_args
10761108
+ other_inner_args
10771109
)
@@ -1081,7 +1113,7 @@ def wrap_into_list(x):
10811113
+ mit_sot_inner_outputs
10821114
+ sit_sot_inner_outputs
10831115
+ nit_sot_inner_outputs
1084-
+ shared_inner_outputs
1116+
+ untraced_sit_sot_inner_outputs
10851117
)
10861118
if condition is not None:
10871119
inner_outs.append(condition)
@@ -1101,7 +1133,7 @@ def wrap_into_list(x):
11011133
mit_mot_out_slices=tuple(tuple(v) for v in mit_mot_out_slices),
11021134
mit_sot_in_slices=tuple(tuple(v) for v in mit_sot_tap_array),
11031135
sit_sot_in_slices=tuple((-1,) for x in range(n_sit_sot)),
1104-
n_shared_outs=n_shared_outs,
1136+
n_untraced_sit_sot_outs=n_untraced_sit_sot_outs,
11051137
n_nit_sot=n_nit_sot,
11061138
n_non_seqs=len(other_shared_inner_args) + len(other_inner_args),
11071139
as_while=as_while,
@@ -1127,7 +1159,7 @@ def wrap_into_list(x):
11271159
+ mit_mot_scan_inputs
11281160
+ mit_sot_scan_inputs
11291161
+ sit_sot_scan_inputs
1130-
+ shared_scan_inputs
1162+
+ untraced_sit_sot_scan_inputs
11311163
+ [actual_n_steps for x in range(n_nit_sot)]
11321164
+ other_shared_scan_args
11331165
+ other_scan_args
@@ -1173,13 +1205,26 @@ def remove_dimensions(outs, offsets=None):
11731205
nit_sot_outs = remove_dimensions(scan_outs[offset : offset + n_nit_sot])
11741206

11751207
offset += n_nit_sot
1176-
for idx, update_rule in enumerate(scan_outs[offset : offset + n_shared_outs]):
1177-
update_map[shared_scan_inputs[idx]] = update_rule
11781208

1179-
_scan_out_list = mit_sot_outs + sit_sot_outs + nit_sot_outs
1209+
# Support for explicit untraced sit_sot
1210+
n_explicit_untraced_sit_sot_outs = len(untraced_sit_sot_rightOrder)
1211+
untraced_sit_sot_outs = scan_outs[
1212+
offset : offset + n_explicit_untraced_sit_sot_outs
1213+
]
1214+
1215+
offset += n_explicit_untraced_sit_sot_outs
1216+
for idx, update_rule in enumerate(scan_outs[offset:]):
1217+
update_map[untraced_sit_sot_scan_inputs[idx]] = update_rule
1218+
1219+
_scan_out_list = mit_sot_outs + sit_sot_outs + nit_sot_outs + untraced_sit_sot_outs
11801220
# Step 10. I need to reorder the outputs to be in the order expected by
11811221
# the user
1182-
rightOrder = mit_sot_rightOrder + sit_sot_rightOrder + nit_sot_rightOrder
1222+
rightOrder = (
1223+
mit_sot_rightOrder
1224+
+ sit_sot_rightOrder
1225+
+ untraced_sit_sot_rightOrder
1226+
+ nit_sot_rightOrder
1227+
)
11831228
scan_out_list = [None] * len(rightOrder)
11841229
for idx, pos in enumerate(rightOrder):
11851230
if pos >= 0:

0 commit comments

Comments
 (0)