@@ -284,6 +284,8 @@ def _print_data_requirements(self) -> None:
284284 Prints a short report to the terminal about the data needed for the model, including their names, shapes,
285285 and named dimensions.
286286 """
287+ if not self .data_info :
288+ return
287289
288290 out = ""
289291 for data , info in self .data_info .items ():
@@ -618,63 +620,6 @@ def _get_matrix_shape_and_dims(
618620
619621 return shape , dims
620622
621- def _get_output_shape_and_dims (
622- self , idata : InferenceData , filter_output : str
623- ) -> tuple [
624- Optional [tuple [int ]], Optional [tuple [int ]], Optional [tuple [str ]], Optional [tuple [str ]]
625- ]:
626- """
627- Get the shapes and dimensions of the output variables from the provided InferenceData.
628-
629- This method extracts the shapes and dimensions of the output variables representing the state estimates
630- and covariances from the provided ArviZ InferenceData object. The state estimates are obtained from the
631- specified `filter_output` mode, which can be one of "filtered", "predicted", or "smoothed".
632-
633- Parameters
634- ----------
635- idata : arviz.InferenceData
636- The ArviZ InferenceData object containing the posterior samples.
637-
638- filter_output : str
639- The name of the filter output whose shape is being checked. It can be one of "filtered",
640- "predicted", or "smoothed".
641-
642- Returns
643- -------
644- mu_shape: tuple(int, int) or None
645- Shape of the mean vectors returned by the Kalman filter. Should be (n_data_timesteps, k_states).
646- If named dimensions are found, this will be None.
647-
648- cov_shape: tuple (int, int, int) or None
649- Shape of the hidden state covariance matrices returned by the Kalman filter. Should be
650- (n_data_timesteps, k_states, k_states).
651- If named dimensions are found, this will be None.
652-
653- mu_dims: tuple(str, str) or None
654- *Default* named dimensions associated with the mean vectors returned by the Kalman filter, or None if
655- the default names are not found.
656-
657- cov_dims: tuple (str, str, str) or None
658- *Default* named dimensions associated with the covariance matrices returned by the Kalman filter, or None if
659- the default names are not found.
660- """
661-
662- mu_dims = None
663- cov_dims = None
664-
665- mu_shape = idata [f"{ filter_output } _state" ].values .shape [2 :]
666- cov_shape = idata [f"{ filter_output } _covariance" ].values .shape [2 :]
667-
668- if all ([dim in self ._fit_coords for dim in [TIME_DIM , ALL_STATE_DIM , ALL_STATE_AUX_DIM ]]):
669- time_dim = TIME_DIM
670- mu_dims = [time_dim , ALL_STATE_DIM ]
671- cov_dims = [time_dim , ALL_STATE_DIM , ALL_STATE_AUX_DIM ]
672-
673- mu_shape = None
674- cov_shape = None
675-
676- return mu_shape , cov_shape , mu_dims , cov_dims
677-
678623 def _insert_random_variables (self ):
679624 """
680625 Replace pytensor symbolic variables with PyMC random variables.
@@ -1506,11 +1451,11 @@ def forecast(
15061451 "Scenario-based forcasting with exogenous variables not currently supported"
15071452 )
15081453
1509- dims = None
15101454 temp_coords = self ._fit_coords .copy ()
15111455
15121456 filter_time_dim = TIME_DIM
15131457
1458+ dims = None
15141459 if all ([dim in temp_coords for dim in [filter_time_dim , ALL_STATE_DIM , OBS_STATE_DIM ]]):
15151460 dims = [TIME_DIM , ALL_STATE_DIM , OBS_STATE_DIM ]
15161461
@@ -1544,14 +1489,10 @@ def forecast(
15441489 temp_coords ["data_time" ] = time_index
15451490 temp_coords [TIME_DIM ] = forecast_index
15461491
1547- mu_shape , cov_shape , mu_dims , cov_dims = self ._get_output_shape_and_dims (
1548- idata .posterior , filter_output
1549- )
1550-
1551- if mu_dims is not None :
1552- mu_dims = ["data_time" ] + mu_dims [1 :]
1553- if cov_dims is not None :
1554- cov_dims = ["data_time" ] + cov_dims [1 :]
1492+ mu_dims , cov_dims = None , None
1493+ if all ([dim in self ._fit_coords for dim in [TIME_DIM , ALL_STATE_DIM , ALL_STATE_AUX_DIM ]]):
1494+ mu_dims = ["data_time" , ALL_STATE_DIM ]
1495+ cov_dims = ["data_time" , ALL_STATE_DIM , ALL_STATE_AUX_DIM ]
15551496
15561497 with pm .Model (coords = temp_coords ):
15571498 [
0 commit comments