1111from pytensor .graph .op import get_test_value
1212from pytensor .graph .replace import clone_replace
1313from pytensor .graph .traversal import explicit_graph_inputs
14+ from pytensor .graph .type import HasShape
1415from pytensor .graph .utils import MissingInputError , TestValueError
1516from pytensor .scan .op import Scan , ScanInfo
1617from pytensor .scan .utils import expand_empty , safe_new , until
@@ -706,6 +707,12 @@ def wrap_into_list(x):
706707 sit_sot_inner_outputs = []
707708 sit_sot_rightOrder = []
708709
710+ n_untraced_sit_sot_outs = 0
711+ untraced_sit_sot_scan_inputs = []
712+ untraced_sit_sot_inner_inputs = []
713+ untraced_sit_sot_inner_outputs = []
714+ untraced_sit_sot_rightOrder = []
715+
709716 # go through outputs picking up time slices as needed
710717 for i , init_out in enumerate (outs_info ):
711718 # Note that our convention dictates that if an output uses
@@ -741,17 +748,35 @@ def wrap_into_list(x):
741748 # We need now to allocate space for storing the output and copy
742749 # the initial state over. We do this using the expand function
743750 # defined in scan utils
744- sit_sot_scan_inputs .append (
745- expand_empty (
746- shape_padleft (actual_arg ),
747- actual_n_steps ,
751+ if isinstance (actual_arg .type , HasShape ):
752+ sit_sot_scan_inputs .append (
753+ expand_empty (
754+ shape_padleft (actual_arg ),
755+ actual_n_steps ,
756+ )
748757 )
749- )
758+ sit_sot_inner_slices .append (actual_arg )
759+
760+ sit_sot_inner_inputs .append (arg )
761+ sit_sot_rightOrder .append (i )
762+ n_sit_sot += 1
763+ else :
764+ # Assume variables without shape cannot be stacked (e.g., RNG variables)
765+ # Because this is new, issue a warning to inform the user, except for RNG, which were the main reason for this feature
766+ from pytensor .tensor .random .type import RandomType
750767
751- sit_sot_inner_slices .append (actual_arg )
752- sit_sot_inner_inputs .append (arg )
753- sit_sot_rightOrder .append (i )
754- n_sit_sot += 1
768+ if not isinstance (arg .type , RandomType ):
769+ warnings .warn (
770+ (
771+ f"Output { actual_arg } (index { i } ) with type { actual_arg .type } will be treated as untraced variable in scan. "
772+ "Only the last value will be returned, not the entire sequence."
773+ ),
774+ UserWarning ,
775+ )
776+ untraced_sit_sot_scan_inputs .append (actual_arg )
777+ untraced_sit_sot_inner_inputs .append (arg )
778+ n_untraced_sit_sot_outs += 1
779+ untraced_sit_sot_rightOrder .append (i )
755780
756781 elif init_out .get ("taps" , None ):
757782 if np .any (np .array (init_out .get ("taps" , [])) > 0 ):
@@ -802,9 +827,10 @@ def wrap_into_list(x):
802827 # a map); in that case we do not have to do anything ..
803828
804829 # Re-order args
805- max_mit_sot = np .max ([- 1 , * mit_sot_rightOrder ]) + 1
806- max_sit_sot = np .max ([- 1 , * sit_sot_rightOrder ]) + 1
807- n_elems = np .max ([max_mit_sot , max_sit_sot ])
830+ max_mit_sot = max (mit_sot_rightOrder , default = - 1 ) + 1
831+ max_sit_sot = max (sit_sot_rightOrder , default = - 1 ) + 1
832+ max_untraced_sit_sot_outs = max (untraced_sit_sot_rightOrder , default = - 1 ) + 1
833+ n_elems = np .max ((max_mit_sot , max_sit_sot , max_untraced_sit_sot_outs ))
808834 _ordered_args = [[] for x in range (n_elems )]
809835 offset = 0
810836 for idx in range (n_mit_sot ):
@@ -825,6 +851,11 @@ def wrap_into_list(x):
825851 else :
826852 _ordered_args [sit_sot_rightOrder [idx ]] = [sit_sot_inner_inputs [idx ]]
827853
854+ for idx in range (n_untraced_sit_sot_outs ):
855+ _ordered_args [untraced_sit_sot_rightOrder [idx ]] = [
856+ untraced_sit_sot_inner_inputs [idx ]
857+ ]
858+
828859 ordered_args = list (chain .from_iterable (_ordered_args ))
829860 if single_step_requested :
830861 args = inner_slices + ordered_args + non_seqs
@@ -939,18 +970,19 @@ def wrap_into_list(x):
939970 if "taps" in out and out ["taps" ] != [- 1 ]:
940971 mit_sot_inner_outputs .append (outputs [i ])
941972
942- # Step 5.2 Outputs with tap equal to -1
973+ # Step 5.2 Outputs with tap equal to -1 (traced and untraced)
943974 for i , out in enumerate (outs_info ):
944975 if "taps" in out and out ["taps" ] == [- 1 ]:
945- sit_sot_inner_outputs .append (outputs [i ])
976+ output = outputs [i ]
977+ if isinstance (output .type , HasShape ):
978+ sit_sot_inner_outputs .append (output )
979+ else :
980+ untraced_sit_sot_inner_outputs .append (output )
946981
947982 # Step 5.3 Outputs that correspond to update rules of shared variables
948- inner_replacements = {}
949- n_shared_outs = 0
950- shared_scan_inputs = []
951- shared_inner_inputs = []
952- shared_inner_outputs = []
983+ # This whole special logic for shared variables is deprecated
953984 sit_sot_shared = []
985+ inner_replacements = {}
954986 no_update_shared_inputs = []
955987 for input in dummy_inputs :
956988 if not isinstance (input .variable , SharedVariable ):
@@ -1003,10 +1035,10 @@ def wrap_into_list(x):
10031035 sit_sot_shared .append (input .variable )
10041036
10051037 else :
1006- shared_inner_inputs .append (new_var )
1007- shared_scan_inputs .append (input .variable )
1008- shared_inner_outputs .append (input .update )
1009- n_shared_outs += 1
1038+ untraced_sit_sot_inner_inputs .append (new_var )
1039+ untraced_sit_sot_scan_inputs .append (input .variable )
1040+ untraced_sit_sot_inner_outputs .append (input .update )
1041+ n_untraced_sit_sot_outs += 1
10101042 else :
10111043 no_update_shared_inputs .append (input )
10121044
@@ -1071,7 +1103,7 @@ def wrap_into_list(x):
10711103 + mit_mot_inner_inputs
10721104 + mit_sot_inner_inputs
10731105 + sit_sot_inner_inputs
1074- + shared_inner_inputs
1106+ + untraced_sit_sot_inner_inputs
10751107 + other_shared_inner_args
10761108 + other_inner_args
10771109 )
@@ -1081,7 +1113,7 @@ def wrap_into_list(x):
10811113 + mit_sot_inner_outputs
10821114 + sit_sot_inner_outputs
10831115 + nit_sot_inner_outputs
1084- + shared_inner_outputs
1116+ + untraced_sit_sot_inner_outputs
10851117 )
10861118 if condition is not None :
10871119 inner_outs .append (condition )
@@ -1101,7 +1133,7 @@ def wrap_into_list(x):
11011133 mit_mot_out_slices = tuple (tuple (v ) for v in mit_mot_out_slices ),
11021134 mit_sot_in_slices = tuple (tuple (v ) for v in mit_sot_tap_array ),
11031135 sit_sot_in_slices = tuple ((- 1 ,) for x in range (n_sit_sot )),
1104- n_shared_outs = n_shared_outs ,
1136+ n_untraced_sit_sot_outs = n_untraced_sit_sot_outs ,
11051137 n_nit_sot = n_nit_sot ,
11061138 n_non_seqs = len (other_shared_inner_args ) + len (other_inner_args ),
11071139 as_while = as_while ,
@@ -1127,7 +1159,7 @@ def wrap_into_list(x):
11271159 + mit_mot_scan_inputs
11281160 + mit_sot_scan_inputs
11291161 + sit_sot_scan_inputs
1130- + shared_scan_inputs
1162+ + untraced_sit_sot_scan_inputs
11311163 + [actual_n_steps for x in range (n_nit_sot )]
11321164 + other_shared_scan_args
11331165 + other_scan_args
@@ -1173,13 +1205,26 @@ def remove_dimensions(outs, offsets=None):
11731205 nit_sot_outs = remove_dimensions (scan_outs [offset : offset + n_nit_sot ])
11741206
11751207 offset += n_nit_sot
1176- for idx , update_rule in enumerate (scan_outs [offset : offset + n_shared_outs ]):
1177- update_map [shared_scan_inputs [idx ]] = update_rule
11781208
1179- _scan_out_list = mit_sot_outs + sit_sot_outs + nit_sot_outs
1209+ # Support for explicit untraced sit_sot
1210+ n_explicit_untraced_sit_sot_outs = len (untraced_sit_sot_rightOrder )
1211+ untraced_sit_sot_outs = scan_outs [
1212+ offset : offset + n_explicit_untraced_sit_sot_outs
1213+ ]
1214+
1215+ offset += n_explicit_untraced_sit_sot_outs
1216+ for idx , update_rule in enumerate (scan_outs [offset :]):
1217+ update_map [untraced_sit_sot_scan_inputs [idx ]] = update_rule
1218+
1219+ _scan_out_list = mit_sot_outs + sit_sot_outs + nit_sot_outs + untraced_sit_sot_outs
11801220 # Step 10. I need to reorder the outputs to be in the order expected by
11811221 # the user
1182- rightOrder = mit_sot_rightOrder + sit_sot_rightOrder + nit_sot_rightOrder
1222+ rightOrder = (
1223+ mit_sot_rightOrder
1224+ + sit_sot_rightOrder
1225+ + untraced_sit_sot_rightOrder
1226+ + nit_sot_rightOrder
1227+ )
11831228 scan_out_list = [None ] * len (rightOrder )
11841229 for idx , pos in enumerate (rightOrder ):
11851230 if pos >= 0 :
0 commit comments