@@ -233,10 +233,9 @@ def __init__(
233233 self ._fit_coords : dict [str , Sequence [str ]] | None = None
234234 self ._fit_dims : dict [str , Sequence [str ]] | None = None
235235 self ._fit_data : pt .TensorVariable | None = None
236+ self ._fit_exog_data : dict [str , dict ] = {}
236237
237238 self ._needs_exog_data = None
238- self ._exog_names = []
239- self ._exog_data_info = {}
240239 self ._name_to_variable = {}
241240 self ._name_to_data = {}
242241
@@ -671,7 +670,7 @@ def _save_exogenous_data_info(self):
671670 pymc_mod = modelcontext (None )
672671 for data_name in self .data_names :
673672 data = pymc_mod [data_name ]
674- self ._exog_data_info [data_name ] = {
673+ self ._fit_exog_data [data_name ] = {
675674 "name" : data_name ,
676675 "value" : data .get_value (),
677676 "dims" : pymc_mod .named_vars_to_dims .get (data_name , None ),
@@ -685,7 +684,7 @@ def _insert_random_variables(self):
685684 --------
686685 .. code:: python
687686
688- ss_mod = pmss.BayesianSARIMA (order=(2, 0, 2), verbose=False, stationary_initialization=True)
687+ ss_mod = pmss.BayesianSARIMAX (order=(2, 0, 2), verbose=False, stationary_initialization=True)
689688 with pm.Model():
690689 x0 = pm.Normal('x0', size=ss_mod.k_states)
691690 ar_params = pm.Normal('ar_params', size=ss_mod.p)
@@ -1082,7 +1081,7 @@ def _kalman_filter_outputs_from_dummy_graph(
10821081
10831082 for name in self .data_names :
10841083 if name not in pm_mod :
1085- pm .Data (** self ._exog_data_info [name ])
1084+ pm .Data (** self ._fit_exog_data [name ])
10861085
10871086 self ._insert_data_variables ()
10881087
@@ -1229,7 +1228,7 @@ def _sample_conditional(
12291228 method = mvn_method ,
12301229 )
12311230
1232- obs_mu = (Z @ mu [..., None ]).squeeze (- 1 )
1231+ obs_mu = d + (Z @ mu [..., None ]).squeeze (- 1 )
12331232 obs_cov = Z @ cov @ pt .swapaxes (Z , - 2 , - 1 ) + H
12341233
12351234 SequenceMvNormal (
@@ -1351,7 +1350,7 @@ def _sample_unconditional(
13511350 self ._insert_random_variables ()
13521351
13531352 for name in self .data_names :
1354- pm .Data (** self ._exog_data_info [name ])
1353+ pm .Data (** self ._fit_exog_data [name ])
13551354
13561355 self ._insert_data_variables ()
13571356
@@ -1651,7 +1650,7 @@ def sample_statespace_matrices(
16511650 self ._insert_random_variables ()
16521651
16531652 for name in self .data_names :
1654- pm .Data (** self ._exog_data_info [name ])
1653+ pm .Data (** self .data_info [name ])
16551654
16561655 self ._insert_data_variables ()
16571656 matrices = self .unpack_statespace ()
@@ -1703,7 +1702,7 @@ def sample_filter_outputs(
17031702
17041703 if self .data_names :
17051704 for name in self .data_names :
1706- pm .Data (** self ._exog_data_info [name ])
1705+ pm .Data (** self ._fit_exog_data [name ])
17071706
17081707 self ._insert_data_variables ()
17091708
@@ -1846,7 +1845,7 @@ def _validate_scenario_data(
18461845 }
18471846
18481847 if self ._needs_exog_data and scenario is None :
1849- exog_str = "," .join (self ._exog_names )
1848+ exog_str = "," .join (self .data_names )
18501849 suffix = "s" if len (exog_str ) > 1 else ""
18511850 raise ValueError (
18521851 f"This model was fit using exogenous data. Forecasting cannot be performed without "
@@ -1855,7 +1854,7 @@ def _validate_scenario_data(
18551854
18561855 if isinstance (scenario , dict ):
18571856 for name , data in scenario .items ():
1858- if name not in self ._exog_names :
1857+ if name not in self .data_names :
18591858 raise ValueError (
18601859 f"Scenario data provided for variable '{ name } ', which is not an exogenous variable "
18611860 f"used to fit the model."
@@ -1896,12 +1895,12 @@ def _validate_scenario_data(
18961895 # name should only be None on the first non-recursive call. We only arrive to this branch in that case
18971896 # if a non-dictionary was passed, which in turn should only happen if only a single exogenous data
18981897 # needs to be set.
1899- if len (self ._exog_names ) > 1 :
1898+ if len (self .data_names ) > 1 :
19001899 raise ValueError (
19011900 "Multiple exogenous variables were used to fit the model. Provide a dictionary of "
19021901 "scenario data instead."
19031902 )
1904- name = self ._exog_names [0 ]
1903+ name = self .data_names [0 ]
19051904
19061905 # Omit dataframe from this basic shape check so we can give more detailed information about missing columns
19071906 # in the next check
@@ -2103,7 +2102,7 @@ def _finalize_scenario_initialization(
21032102 return scenario
21042103
21052104 # This was already checked as valid
2106- name = self ._exog_names [0 ] if name is None else name
2105+ name = self .data_names [0 ] if name is None else name
21072106
21082107 # Small tidying up in the case we just have a single scenario that's already a dataframe.
21092108 if isinstance (scenario , pd .DataFrame | pd .Series ):
0 commit comments