from abc import ABC
+ from functools import partial

import numpy as np
import pytensor
from pytensor.raise_op import Assert
from pytensor.tensor import TensorVariable
from pytensor.tensor.slinalg import solve_triangular
- from pytensor.graph.replace import vectorize_graph

from pymc_extras.statespace.filters.utilities import (
    quad_form_sym,
    split_vars_into_seq_and_nonseq,
    stabilize,
)
- from pymc_extras.statespace.utils.constants import JITTER_DEFAULT, MISSING_FILL
+ from pymc_extras.statespace.utils.constants import JITTER_DEFAULT, MISSING_FILL, ALL_KF_OUTPUT_NAMES

MVN_CONST = pt.log(2 * pt.constant(np.pi, dtype="float64"))
PARAM_NAMES = ["c", "d", "T", "Z", "R", "H", "Q"]
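MVN_CONST is log(2*pi), the constant term of the multivariate normal density. In the Kalman filter each observation contributes -0.5 * (n_endog * log(2*pi) + log|F_t| + v_t' F_t^{-1} v_t) to the log-likelihood, where v_t is the one-step-ahead forecast error and F_t its covariance; defining the constant at module level presumably just avoids recomputing it inside the scan body.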
@@ -65,22 +65,56 @@ def check_params(self, data, a0, P0, c, d, T, Z, R, H, Q):
        """
        return data, a0, P0, c, d, T, Z, R, H, Q

-     def has_batched_input(self, data, a0, P0, c, d, T, Z, R, H, Q):
-         """
-         Check if any of the inputs are batched.
-         """
-         return any(x.ndim > CORE_NDIM[i] for i, x in enumerate([data, a0, P0, c, d, T, Z, R, H, Q]))
-
-     def get_dummy_core_inputs(self, data, a0, P0, c, d, T, Z, R, H, Q):
-         """
-         Get dummy inputs for the core parameters.
-         """
-         out = []
-         for x, core_ndim in zip([data, a0, P0, c, d, T, Z, R, H, Q], CORE_NDIM):
-             out.append(
-                 pt.tensor(f"{x.name}_core_case", dtype=x.dtype, shape=x.type.shape[-core_ndim:])
-             )
-         return out
+     def _make_gufunc_signature(self, inputs):
+         states = "s"
+         obs = "p"
+         exog = "r"
+         time = "t"
+
+         matrix_to_shape = {
+             "data": (time, obs),
+             "a0": (states,),
+             "x0": (states,),
+             "P0": (states, states),
+             "c": (states,),
+             "d": (obs,),
+             "T": (states, states),
+             "Z": (obs, states),
+             "R": (states, exog),
+             "H": (obs, obs),
+             "Q": (exog, exog),
+             "filtered_states": (time, states),
+             "filtered_covariances": (time, states, states),
+             "predicted_states": (time, states),
+             "predicted_covariances": (time, states, states),
+             "observed_states": (time, obs),
+             "observed_covariances": (time, obs, obs),
+             "smoothed_states": (time, states),
+             "smoothed_covariances": (time, states, states),
+             "loglike_obs": (time,),
+         }
+         input_shapes = []
+         output_shapes = []
+
+         for matrix in inputs:
+             name = matrix.name
+             input_shapes.append(matrix_to_shape[name])
+
+         for name in [
+             "filtered_states",
+             "predicted_states",
+             "smoothed_states",
+             "filtered_covariances",
+             "predicted_covariances",
+             "smoothed_covariances",
+             "loglike_obs",
+         ]:
+             output_shapes.append(matrix_to_shape[name])
+
+         input_signature = ",".join(["(" + ",".join(shapes) + ")" for shapes in input_shapes])
+         output_signature = ",".join(["(" + ",".join(shapes) + ")" for shapes in output_shapes])
+
+         return f"{input_signature} -> {output_signature}"

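For the usual input ordering (data, a0, P0, c, d, T, Z, R, H, Q), this helper produces the gufunc signature

(t,p),(s),(s,s),(s),(p),(s,s),(p,s),(s,r),(p,p),(r,r) -> (t,s),(t,s),(t,s),(t,s,s),(t,s,s),(t,s,s),(t)

where t is the number of time steps, p the number of observed series, s the number of hidden states, and r the number of shocks (labelled "exog" above). Any dimensions to the left of these core dimensions on the actual inputs are treated as batch dimensions by pt.vectorize.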
    @staticmethod
    def add_check_on_time_varying_shapes(
@@ -150,7 +184,7 @@ def unpack_args(self, args) -> tuple:

        return y, a0, P0, c, d, T, Z, R, H, Q

-     def build_graph(
+     def _build_graph(
        self,
        data,
        a0,
@@ -206,18 +240,13 @@ def build_graph(

        self.missing_fill_value = missing_fill_value
        self.cov_jitter = cov_jitter
-         is_batched = self.has_batched_input(data, a0, P0, c, d, T, Z, R, H, Q)

        [R_shape] = constant_fold([R.shape], raise_not_constant=False)
        [Z_shape] = constant_fold([Z.shape], raise_not_constant=False)

        self.n_states, self.n_shocks = R_shape[-2:]
        self.n_endog = Z_shape[-2]

-         if is_batched:
-             batched_inputs = [data, a0, P0, c, d, T, Z, R, H, Q]
-             data, a0, P0, c, d, T, Z, R, H, Q = self.get_dummy_core_inputs(*batched_inputs)
-
        data, a0, P0, *params = self.check_params(data, a0, P0, c, d, T, Z, R, H, Q)

        sequences, non_sequences, seq_names, non_seq_names = split_vars_into_seq_and_nonseq(
@@ -241,15 +270,47 @@ def build_graph(

        filter_results = self._postprocess_scan_results(results, a0, P0, n=data.type.shape[0])

-         if is_batched:
-             vec_subs = dict(zip([data, a0, P0, c, d, T, Z, R, H, Q], batched_inputs))
-             filter_results = vectorize_graph(filter_results, vec_subs)
-
        if return_updates:
            return filter_results, updates

        return filter_results

+     def build_graph(
+         self,
+         data,
+         a0,
+         P0,
+         c,
+         d,
+         T,
+         Z,
+         R,
+         H,
+         Q,
+         mode=None,
+         return_updates=False,
+         missing_fill_value=None,
+         cov_jitter=None,
+     ) -> list[TensorVariable] | tuple[list[TensorVariable], dict]:
+         """
+         Build the vectorized computation graph for the Kalman filter.
+         """
+         signature = self._make_gufunc_signature(
+             [data, a0, P0, c, d, T, Z, R, H, Q],
+         )
+         fn = partial(
+             self._build_graph,
+             mode=mode,
+             return_updates=return_updates,
+             missing_fill_value=missing_fill_value,
+             cov_jitter=cov_jitter,
+         )
+         filter_outputs = pt.vectorize(fn, signature=signature)(data, a0, P0, c, d, T, Z, R, H, Q)
+         for output, name in zip(filter_outputs, ALL_KF_OUTPUT_NAMES):
+             output.name = name
+
+         return filter_outputs
+
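A minimal sketch of the mechanism the new build_graph relies on (illustrative only; the function, tensor names, and shapes below are invented for the example and are not part of this diff): pt.vectorize traces the wrapped function on dummy inputs carrying only the core dimensions from the signature, then maps the resulting graph over any leading batch dimensions of the real inputs.

import pytensor.tensor as pt

def core_matvec(A, x):
    # Written only for the core dims in the signature: A is (m, n), x is (n,)
    return A @ x

A = pt.tensor("A", shape=(None, 5, 5))  # extra leading dim is treated as a batch dimension
x = pt.tensor("x", shape=(5,))          # no batch dim, so it is broadcast across the batch
y = pt.vectorize(core_matvec, signature="(m,n),(n)->(m)")(A, x)
# y keeps A's leading batch dimension; build_graph applies the same idea to map
# _build_graph (and its internal scan) over batched statespace matrices.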
    def _postprocess_scan_results(self, results, a0, P0, n) -> list[TensorVariable]:
        """
        Transform the values returned by the Kalman Filter scan into a form expected by users. In particular: