1+ import typing
12import warnings
23from itertools import chain
34
1112from pytensor .graph .op import get_test_value
1213from pytensor .graph .replace import clone_replace
1314from pytensor .graph .traversal import explicit_graph_inputs
15+ from pytensor .graph .type import HasShape
1416from pytensor .graph .utils import MissingInputError , TestValueError
1517from pytensor .scan .op import Scan , ScanInfo
1618from pytensor .scan .utils import expand_empty , safe_new , until
2224from pytensor .updates import OrderedUpdates
2325
2426
27+ if typing .TYPE_CHECKING :
28+ from pytensor .tensor .type import TensorVariable
29+
30+
2531def get_updates_and_outputs (ls ):
2632 """Recognize and order the updates, outputs, and stopping condition for a `Scan`.
2733
@@ -469,7 +475,7 @@ def wrap_into_list(x):
469475
470476 # Make sure we get rid of numpy arrays or ints or anything like that
471477 # passed as inputs to scan
472- non_seqs = []
478+ non_seqs : list [ Variable ] = []
473479 for elem in wrap_into_list (non_sequences ):
474480 if not isinstance (elem , Variable ):
475481 non_seqs .append (pt .as_tensor_variable (elem ))
@@ -685,10 +691,10 @@ def wrap_into_list(x):
685691
686692 # MIT_MOT -- not provided by the user only by the grad function
687693 n_mit_mot = 0
688- mit_mot_scan_inputs = []
689- mit_mot_inner_inputs = []
690- mit_mot_inner_outputs = []
691- mit_mot_out_slices = []
694+ mit_mot_scan_inputs : list [ TensorVariable ] = []
695+ mit_mot_inner_inputs : list [ TensorVariable ] = []
696+ mit_mot_inner_outputs : list [ TensorVariable ] = []
697+ mit_mot_out_slices : list [ TensorVariable ] = []
692698
693699 # SIT_SOT -- provided by the user
694700 n_mit_sot = 0
@@ -706,6 +712,12 @@ def wrap_into_list(x):
706712 sit_sot_inner_outputs = []
707713 sit_sot_rightOrder = []
708714
715+ n_untraced_sit_sot_outs = 0
716+ untraced_sit_sot_scan_inputs = []
717+ untraced_sit_sot_inner_inputs = []
718+ untraced_sit_sot_inner_outputs = []
719+ untraced_sit_sot_rightOrder = []
720+
709721 # go through outputs picking up time slices as needed
710722 for i , init_out in enumerate (outs_info ):
711723 # Note that our convention dictates that if an output uses
@@ -741,17 +753,35 @@ def wrap_into_list(x):
741753 # We need now to allocate space for storing the output and copy
742754 # the initial state over. We do this using the expand function
743755 # defined in scan utils
744- sit_sot_scan_inputs .append (
745- expand_empty (
746- shape_padleft (actual_arg ),
747- actual_n_steps ,
756+ if isinstance (actual_arg .type , HasShape ):
757+ sit_sot_scan_inputs .append (
758+ expand_empty (
759+ shape_padleft (actual_arg ),
760+ actual_n_steps ,
761+ )
748762 )
749- )
763+ sit_sot_inner_slices . append ( actual_arg )
750764
751- sit_sot_inner_slices .append (actual_arg )
752- sit_sot_inner_inputs .append (arg )
753- sit_sot_rightOrder .append (i )
754- n_sit_sot += 1
765+ sit_sot_inner_inputs .append (arg )
766+ sit_sot_rightOrder .append (i )
767+ n_sit_sot += 1
768+ else :
769+ # Assume variables without shape cannot be stacked (e.g., RNG variables)
770+ # Because this is new, issue a warning to inform the user, except for RNG, which were the main reason for this feature
771+ from pytensor .tensor .random .type import RandomType
772+
773+ if not isinstance (arg .type , RandomType ):
774+ warnings .warn (
775+ (
776+ f"Output { actual_arg } (index { i } ) with type { actual_arg .type } will be treated as untraced variable in scan. "
777+ "Only the last value will be returned, not the entire sequence."
778+ ),
779+ UserWarning ,
780+ )
781+ untraced_sit_sot_scan_inputs .append (actual_arg )
782+ untraced_sit_sot_inner_inputs .append (arg )
783+ n_untraced_sit_sot_outs += 1
784+ untraced_sit_sot_rightOrder .append (i )
755785
756786 elif init_out .get ("taps" , None ):
757787 if np .any (np .array (init_out .get ("taps" , [])) > 0 ):
@@ -802,10 +832,11 @@ def wrap_into_list(x):
802832 # a map); in that case we do not have to do anything ..
803833
804834 # Re-order args
805- max_mit_sot = np .max ([- 1 , * mit_sot_rightOrder ]) + 1
806- max_sit_sot = np .max ([- 1 , * sit_sot_rightOrder ]) + 1
807- n_elems = np .max ([max_mit_sot , max_sit_sot ])
808- _ordered_args = [[] for x in range (n_elems )]
835+ max_mit_sot = max (mit_sot_rightOrder , default = - 1 ) + 1
836+ max_sit_sot = max (sit_sot_rightOrder , default = - 1 ) + 1
837+ max_untraced_sit_sot_outs = max (untraced_sit_sot_rightOrder , default = - 1 ) + 1
838+ n_elems = np .max ((max_mit_sot , max_sit_sot , max_untraced_sit_sot_outs ))
839+ _ordered_args : list [list [Variable ]] = [[] for x in range (n_elems )]
809840 offset = 0
810841 for idx in range (n_mit_sot ):
811842 n_inputs = len (mit_sot_tap_array [idx ])
@@ -825,6 +856,11 @@ def wrap_into_list(x):
825856 else :
826857 _ordered_args [sit_sot_rightOrder [idx ]] = [sit_sot_inner_inputs [idx ]]
827858
859+ for idx in range (n_untraced_sit_sot_outs ):
860+ _ordered_args [untraced_sit_sot_rightOrder [idx ]] = [
861+ untraced_sit_sot_inner_inputs [idx ]
862+ ]
863+
828864 ordered_args = list (chain .from_iterable (_ordered_args ))
829865 if single_step_requested :
830866 args = inner_slices + ordered_args + non_seqs
@@ -939,18 +975,19 @@ def wrap_into_list(x):
939975 if "taps" in out and out ["taps" ] != [- 1 ]:
940976 mit_sot_inner_outputs .append (outputs [i ])
941977
942- # Step 5.2 Outputs with tap equal to -1
978+ # Step 5.2 Outputs with tap equal to -1 (traced and untraced)
943979 for i , out in enumerate (outs_info ):
944980 if "taps" in out and out ["taps" ] == [- 1 ]:
945- sit_sot_inner_outputs .append (outputs [i ])
981+ output = outputs [i ]
982+ if isinstance (output .type , HasShape ):
983+ sit_sot_inner_outputs .append (output )
984+ else :
985+ untraced_sit_sot_inner_outputs .append (output )
946986
947987 # Step 5.3 Outputs that correspond to update rules of shared variables
988+ # This whole special logic for shared variables is deprecated
989+ sit_sot_shared : list [Variable ] = []
948990 inner_replacements = {}
949- n_shared_outs = 0
950- shared_scan_inputs = []
951- shared_inner_inputs = []
952- shared_inner_outputs = []
953- sit_sot_shared = []
954991 no_update_shared_inputs = []
955992 for input in dummy_inputs :
956993 if not isinstance (input .variable , SharedVariable ):
@@ -976,8 +1013,8 @@ def wrap_into_list(x):
9761013
9771014 new_var = safe_new (input .variable )
9781015
979- if getattr ( input .variable , " name" , None ) is not None :
980- new_var .name = input .variable .name + " _copy"
1016+ if input .variable . name is not None :
1017+ new_var .name = f" { input .variable .name } _copy"
9811018
9821019 inner_replacements [input .variable ] = new_var
9831020
@@ -1003,10 +1040,10 @@ def wrap_into_list(x):
10031040 sit_sot_shared .append (input .variable )
10041041
10051042 else :
1006- shared_inner_inputs .append (new_var )
1007- shared_scan_inputs .append (input .variable )
1008- shared_inner_outputs .append (input .update )
1009- n_shared_outs += 1
1043+ untraced_sit_sot_inner_inputs .append (new_var )
1044+ untraced_sit_sot_scan_inputs .append (input .variable )
1045+ untraced_sit_sot_inner_outputs .append (input .update )
1046+ n_untraced_sit_sot_outs += 1
10101047 else :
10111048 no_update_shared_inputs .append (input )
10121049
@@ -1071,7 +1108,7 @@ def wrap_into_list(x):
10711108 + mit_mot_inner_inputs
10721109 + mit_sot_inner_inputs
10731110 + sit_sot_inner_inputs
1074- + shared_inner_inputs
1111+ + untraced_sit_sot_inner_inputs
10751112 + other_shared_inner_args
10761113 + other_inner_args
10771114 )
@@ -1081,7 +1118,7 @@ def wrap_into_list(x):
10811118 + mit_sot_inner_outputs
10821119 + sit_sot_inner_outputs
10831120 + nit_sot_inner_outputs
1084- + shared_inner_outputs
1121+ + untraced_sit_sot_inner_outputs
10851122 )
10861123 if condition is not None :
10871124 inner_outs .append (condition )
@@ -1101,7 +1138,7 @@ def wrap_into_list(x):
11011138 mit_mot_out_slices = tuple (tuple (v ) for v in mit_mot_out_slices ),
11021139 mit_sot_in_slices = tuple (tuple (v ) for v in mit_sot_tap_array ),
11031140 sit_sot_in_slices = tuple ((- 1 ,) for x in range (n_sit_sot )),
1104- n_shared_outs = n_shared_outs ,
1141+ n_untraced_sit_sot_outs = n_untraced_sit_sot_outs ,
11051142 n_nit_sot = n_nit_sot ,
11061143 n_non_seqs = len (other_shared_inner_args ) + len (other_inner_args ),
11071144 as_while = as_while ,
@@ -1127,7 +1164,7 @@ def wrap_into_list(x):
11271164 + mit_mot_scan_inputs
11281165 + mit_sot_scan_inputs
11291166 + sit_sot_scan_inputs
1130- + shared_scan_inputs
1167+ + untraced_sit_sot_scan_inputs
11311168 + [actual_n_steps for x in range (n_nit_sot )]
11321169 + other_shared_scan_args
11331170 + other_scan_args
@@ -1173,13 +1210,26 @@ def remove_dimensions(outs, offsets=None):
11731210 nit_sot_outs = remove_dimensions (scan_outs [offset : offset + n_nit_sot ])
11741211
11751212 offset += n_nit_sot
1176- for idx , update_rule in enumerate (scan_outs [offset : offset + n_shared_outs ]):
1177- update_map [shared_scan_inputs [idx ]] = update_rule
11781213
1179- _scan_out_list = mit_sot_outs + sit_sot_outs + nit_sot_outs
1214+ # Support for explicit untraced sit_sot
1215+ n_explicit_untraced_sit_sot_outs = len (untraced_sit_sot_rightOrder )
1216+ untraced_sit_sot_outs = scan_outs [
1217+ offset : offset + n_explicit_untraced_sit_sot_outs
1218+ ]
1219+
1220+ offset += n_explicit_untraced_sit_sot_outs
1221+ for idx , update_rule in enumerate (scan_outs [offset :]):
1222+ update_map [untraced_sit_sot_scan_inputs [idx ]] = update_rule
1223+
1224+ _scan_out_list = mit_sot_outs + sit_sot_outs + nit_sot_outs + untraced_sit_sot_outs
11801225 # Step 10. I need to reorder the outputs to be in the order expected by
11811226 # the user
1182- rightOrder = mit_sot_rightOrder + sit_sot_rightOrder + nit_sot_rightOrder
1227+ rightOrder = (
1228+ mit_sot_rightOrder
1229+ + sit_sot_rightOrder
1230+ + untraced_sit_sot_rightOrder
1231+ + nit_sot_rightOrder
1232+ )
11831233 scan_out_list = [None ] * len (rightOrder )
11841234 for idx , pos in enumerate (rightOrder ):
11851235 if pos >= 0 :
0 commit comments